From c60d3d6f38a4432192ff210432aac75651dd0545 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 5 Nov 2021 11:46:22 -0500 Subject: [PATCH 1/5] Trainers: split tasks into separate files --- tests/trainers/test_byol.py | 33 +- tests/trainers/test_chesapeake.py | 109 --- .../{test_tasks.py => test_classification.py} | 88 +- tests/trainers/test_landcoverai.py | 91 -- tests/trainers/test_naipchesapeake.py | 91 -- tests/trainers/test_regression.py | 74 ++ tests/trainers/test_segmentation.py | 345 +++++++ tests/trainers/test_sen12ms.py | 94 -- tests/trainers/test_so2sat.py | 115 --- torchgeo/trainers/__init__.py | 18 +- torchgeo/trainers/byol.py | 10 +- torchgeo/trainers/chesapeake.py | 283 ------ .../trainers/{tasks.py => classification.py} | 208 ++-- torchgeo/trainers/landcoverai.py | 259 ----- torchgeo/trainers/naipchesapeake.py | 251 ----- .../trainers/{sen12ms.py => regression.py} | 104 +- torchgeo/trainers/segmentation.py | 921 ++++++++++++++++++ torchgeo/trainers/so2sat.py | 90 -- 18 files changed, 1544 insertions(+), 1640 deletions(-) delete mode 100644 tests/trainers/test_chesapeake.py rename tests/trainers/{test_tasks.py => test_classification.py} (76%) delete mode 100644 tests/trainers/test_landcoverai.py delete mode 100644 tests/trainers/test_naipchesapeake.py create mode 100644 tests/trainers/test_regression.py create mode 100644 tests/trainers/test_segmentation.py delete mode 100644 tests/trainers/test_sen12ms.py delete mode 100644 tests/trainers/test_so2sat.py delete mode 100644 torchgeo/trainers/chesapeake.py rename torchgeo/trainers/{tasks.py => classification.py} (71%) delete mode 100644 torchgeo/trainers/landcoverai.py delete mode 100644 torchgeo/trainers/naipchesapeake.py rename torchgeo/trainers/{sen12ms.py => regression.py} (56%) create mode 100644 torchgeo/trainers/segmentation.py delete mode 100644 torchgeo/trainers/so2sat.py diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 5ef215b0111..ac5e9e2b792 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -19,23 +19,6 @@ from .test_utils import mocked_log -@pytest.fixture(scope="module") -def datamodule() -> ChesapeakeCVPRDataModule: - dm = ChesapeakeCVPRDataModule( - os.path.join("tests", "data", "chesapeake", "cvpr"), - ["de-test"], - ["de-test"], - ["de-test"], - patch_size=4, - patches_per_tile=2, - batch_size=2, - num_workers=0, - ) - dm.prepare_data() - dm.setup() - return dm - - class TestBYOL: def test_custom_augment_fn(self) -> None: encoder = resnet18() @@ -54,6 +37,22 @@ def test_custom_augment_fn(self) -> None: class TestBYOLTask: + @pytest.fixture(scope="class") + def datamodule(self) -> ChesapeakeCVPRDataModule: + dm = ChesapeakeCVPRDataModule( + os.path.join("tests", "data", "chesapeake", "cvpr"), + ["de-test"], + ["de-test"], + ["de-test"], + patch_size=4, + patches_per_tile=2, + batch_size=2, + num_workers=0, + ) + dm.prepare_data() + dm.setup() + return dm + @pytest.fixture(params=["resnet18", "resnet50"]) def config(self, request: SubRequest) -> Dict[str, Any]: task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "byol.yaml")) diff --git a/tests/trainers/test_chesapeake.py b/tests/trainers/test_chesapeake.py deleted file mode 100644 index bb316cca938..00000000000 --- a/tests/trainers/test_chesapeake.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator, cast - -import pytest -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch -from omegaconf import OmegaConf - -from torchgeo.datasets import ChesapeakeCVPRDataModule -from torchgeo.trainers import ChesapeakeCVPRSegmentationTask - -from .test_utils import FakeTrainer, mocked_log - - -@pytest.fixture(scope="module", params=[5, 7]) -def class_set(request: SubRequest) -> int: - return cast(int, request.param) - - -@pytest.fixture(scope="module") -def datamodule(class_set: int) -> ChesapeakeCVPRDataModule: - dm = ChesapeakeCVPRDataModule( - os.path.join("tests", "data", "chesapeake", "cvpr"), - ["de-test"], - ["de-test"], - ["de-test"], - patch_size=32, - patches_per_tile=2, - batch_size=2, - num_workers=0, - class_set=class_set, - ) - dm.prepare_data() - dm.setup() - return dm - - -class TestChesapeakeCVPRSegmentationTask: - @pytest.fixture( - params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "jaccard", "focal"]) - ) - def config(self, class_set: int, request: SubRequest) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - segmentation_model, loss = request.param - task_args["class_set"] = class_set - task_args["segmentation_model"] = segmentation_model - task_args["loss"] = loss - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> ChesapeakeCVPRSegmentationTask: - task = ChesapeakeCVPRSegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_configure_optimizers(self, task: ChesapeakeCVPRSegmentationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_training( - self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test( - self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_invalid_class_set(self, config: Dict[str, Any]) -> None: - config["class_set"] = 6 - error_message = "'class_set' must be either 5 or 7" - with pytest.raises(ValueError, match=error_message): - ChesapeakeCVPRSegmentationTask(**config) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["segmentation_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - ChesapeakeCVPRSegmentationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - ChesapeakeCVPRSegmentationTask(**config) diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_classification.py similarity index 76% rename from tests/trainers/test_tasks.py rename to tests/trainers/test_classification.py index ea648370c03..3aa414f1b2c 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_classification.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, Optional, cast +from typing import Any, Dict, Generator, Optional, Tuple, cast import pytest import pytorch_lightning as pl @@ -14,11 +14,11 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset, TensorDataset -from torchgeo.datasets import CycloneDataModule +from torchgeo.datasets import So2SatDataModule from torchgeo.trainers import ( ClassificationTask, MultiLabelClassificationTask, - RegressionTask, + So2SatClassificationTask, ) from .test_utils import mocked_log @@ -256,61 +256,103 @@ def test_invalid_loss(self, config: Dict[str, Any]) -> None: MultiLabelClassificationTask(**config) -class TestRegressionTask: - @pytest.fixture(scope="class") - def datamodule(self) -> CycloneDataModule: - root = os.path.join("tests", "data", "cyclone") - seed = 0 - batch_size = 1 +class TestSo2SatClassificationTask: + @pytest.fixture(scope="class", params=[("rgb", 3), ("s2", 10)]) + def bands(self, request: SubRequest) -> Tuple[str, int]: + return cast(Tuple[str, int], request.param) + + @pytest.fixture(scope="class", params=[True, False]) + def datamodule( + self, bands: Tuple[str, int], request: SubRequest + ) -> So2SatDataModule: + band_set = bands[0] + unsupervised_mode = request.param + root = os.path.join("tests", "data", "so2sat") + batch_size = 2 num_workers = 0 - dm = CycloneDataModule(root, seed, batch_size, num_workers) + dm = So2SatDataModule( + root, batch_size, num_workers, band_set, unsupervised_mode + ) dm.prepare_data() dm.setup() return dm - @pytest.fixture - def config(self) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "cyclone.yaml") - ) + @pytest.fixture( + params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) + ) + def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: + task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) + task_args["in_channels"] = bands[1] + loss, weights = request.param + task_args["loss"] = loss + task_args["weights"] = weights return task_args @pytest.fixture def task( self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> RegressionTask: - task = RegressionTask(**config) + ) -> So2SatClassificationTask: + task = So2SatClassificationTask(**config) monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] return task - def test_configure_optimizers(self, task: RegressionTask) -> None: + def test_configure_optimizers(self, task: So2SatClassificationTask) -> None: out = task.configure_optimizers() assert "optimizer" in out assert "lr_scheduler" in out def test_training( - self, datamodule: CycloneDataModule, task: RegressionTask + self, datamodule: So2SatDataModule, task: So2SatClassificationTask ) -> None: batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, datamodule: CycloneDataModule, task: RegressionTask + self, datamodule: So2SatDataModule, task: So2SatClassificationTask ) -> None: batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test(self, datamodule: CycloneDataModule, task: RegressionTask) -> None: + def test_test( + self, datamodule: So2SatDataModule, task: So2SatClassificationTask + ) -> None: batch = next(iter(datamodule.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0) + def test_pretrained(self, checkpoint: str) -> None: + task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + task_args["weights"] = checkpoint + with pytest.warns(UserWarning): + So2SatClassificationTask(**task_args) + def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["model"] = "invalid_model" + config["classification_model"] = "invalid_model" error_message = "Model type 'invalid_model' is not valid." with pytest.raises(ValueError, match=error_message): - RegressionTask(**config) + So2SatClassificationTask(**config) + + def test_invalid_loss(self, config: Dict[str, Any]) -> None: + config["loss"] = "invalid_loss" + error_message = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=error_message): + So2SatClassificationTask(**config) + + def test_invalid_weights(self, config: Dict[str, Any]) -> None: + config["weights"] = "invalid_weights" + error_message = "Weight type 'invalid_weights' is not valid." + with pytest.raises(ValueError, match=error_message): + So2SatClassificationTask(**config) + + def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> None: + config["weights"] = checkpoint + config["classification_model"] = "resnet50" + error_message = "Trying to load resnet18 weights into a resnet50" + with pytest.raises(ValueError, match=error_message): + So2SatClassificationTask(**config) diff --git a/tests/trainers/test_landcoverai.py b/tests/trainers/test_landcoverai.py deleted file mode 100644 index bda241bd527..00000000000 --- a/tests/trainers/test_landcoverai.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator, cast - -import pytest -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch -from omegaconf import OmegaConf - -from torchgeo.datasets import LandcoverAIDataModule -from torchgeo.trainers import LandcoverAISegmentationTask - -from .test_utils import FakeTrainer, mocked_log - - -@pytest.fixture(scope="module") -def datamodule() -> LandcoverAIDataModule: - root = os.path.join("tests", "data", "landcoverai") - batch_size = 2 - num_workers = 0 - dm = LandcoverAIDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - -class TestLandcoverAISegmentationTask: - @pytest.fixture( - params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "jaccard", "focal"]) - ) - def config(self, request: SubRequest) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "landcoverai.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - segmentation_model, loss = request.param - task_args["segmentation_model"] = segmentation_model - task_args["loss"] = loss - task_args["verbose"] = True - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> LandcoverAISegmentationTask: - task = LandcoverAISegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_training( - self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test( - self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_configure_optimizers(self, task: LandcoverAISegmentationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["segmentation_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - LandcoverAISegmentationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - LandcoverAISegmentationTask(**config) diff --git a/tests/trainers/test_naipchesapeake.py b/tests/trainers/test_naipchesapeake.py deleted file mode 100644 index 4d233f5fb37..00000000000 --- a/tests/trainers/test_naipchesapeake.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator, cast - -import pytest -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch -from omegaconf import OmegaConf - -from torchgeo.datasets import NAIPChesapeakeDataModule -from torchgeo.trainers import NAIPChesapeakeSegmentationTask - -from .test_utils import FakeTrainer, mocked_log - - -@pytest.fixture(scope="module") -def datamodule() -> NAIPChesapeakeDataModule: - dm = NAIPChesapeakeDataModule( - os.path.join("tests", "data", "naip"), - os.path.join("tests", "data", "chesapeake", "BAYWIDE"), - batch_size=2, - num_workers=0, - ) - dm.patch_size = 32 - dm.prepare_data() - dm.setup() - return dm - - -class TestNAIPChesapeakeSegmentationTask: - @pytest.fixture(params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "ce", "jaccard"])) - def config(self, request: SubRequest) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - segmentation_model, loss = request.param - task_args["segmentation_model"] = segmentation_model - task_args["loss"] = loss - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> NAIPChesapeakeSegmentationTask: - task = NAIPChesapeakeSegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_configure_optimizers(self, task: NAIPChesapeakeSegmentationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_training( - self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test( - self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["segmentation_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - NAIPChesapeakeSegmentationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - NAIPChesapeakeSegmentationTask(**config) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py new file mode 100644 index 00000000000..cfa7e16924b --- /dev/null +++ b/tests/trainers/test_regression.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Any, Dict, Generator, cast + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from omegaconf import OmegaConf + +from torchgeo.datasets import CycloneDataModule +from torchgeo.trainers import RegressionTask + +from .test_utils import mocked_log + + +class TestRegressionTask: + @pytest.fixture(scope="class") + def datamodule(self) -> CycloneDataModule: + root = os.path.join("tests", "data", "cyclone") + seed = 0 + batch_size = 1 + num_workers = 0 + dm = CycloneDataModule(root, seed, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture + def config(self) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "cyclone.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> RegressionTask: + task = RegressionTask(**config) + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_configure_optimizers(self, task: RegressionTask) -> None: + out = task.configure_optimizers() + assert "optimizer" in out + assert "lr_scheduler" in out + + def test_training( + self, datamodule: CycloneDataModule, task: RegressionTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: CycloneDataModule, task: RegressionTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test(self, datamodule: CycloneDataModule, task: RegressionTask) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_invalid_model(self, config: Dict[str, Any]) -> None: + config["model"] = "invalid_model" + error_message = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=error_message): + RegressionTask(**config) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py new file mode 100644 index 00000000000..65e21dc89c1 --- /dev/null +++ b/tests/trainers/test_segmentation.py @@ -0,0 +1,345 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Any, Dict, Generator, Tuple, cast + +import pytest +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch +from omegaconf import OmegaConf + +from torchgeo.datasets import ( + ChesapeakeCVPRDataModule, + LandcoverAIDataModule, + NAIPChesapeakeDataModule, + SEN12MSDataModule, +) +from torchgeo.trainers import ( + ChesapeakeCVPRSegmentationTask, + LandcoverAISegmentationTask, + NAIPChesapeakeSegmentationTask, + SEN12MSSegmentationTask, +) + +from .test_utils import FakeTrainer, mocked_log + + +class TestChesapeakeCVPRSegmentationTask: + @pytest.fixture(scope="class", params=[5, 7]) + def class_set(self, request: SubRequest) -> int: + return cast(int, request.param) + + @pytest.fixture(scope="class") + def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: + dm = ChesapeakeCVPRDataModule( + os.path.join("tests", "data", "chesapeake", "cvpr"), + ["de-test"], + ["de-test"], + ["de-test"], + patch_size=32, + patches_per_tile=2, + batch_size=2, + num_workers=0, + class_set=class_set, + ) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture( + params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "jaccard", "focal"]) + ) + def config(self, class_set: int, request: SubRequest) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + segmentation_model, loss = request.param + task_args["class_set"] = class_set + task_args["segmentation_model"] = segmentation_model + task_args["loss"] = loss + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> ChesapeakeCVPRSegmentationTask: + task = ChesapeakeCVPRSegmentationTask(**config) + trainer = FakeTrainer() + monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_configure_optimizers(self, task: ChesapeakeCVPRSegmentationTask) -> None: + out = task.configure_optimizers() + assert "optimizer" in out + assert "lr_scheduler" in out + + def test_training( + self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test( + self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask + ) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_invalid_class_set(self, config: Dict[str, Any]) -> None: + config["class_set"] = 6 + error_message = "'class_set' must be either 5 or 7" + with pytest.raises(ValueError, match=error_message): + ChesapeakeCVPRSegmentationTask(**config) + + def test_invalid_model(self, config: Dict[str, Any]) -> None: + config["segmentation_model"] = "invalid_model" + error_message = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=error_message): + ChesapeakeCVPRSegmentationTask(**config) + + def test_invalid_loss(self, config: Dict[str, Any]) -> None: + config["loss"] = "invalid_loss" + error_message = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=error_message): + ChesapeakeCVPRSegmentationTask(**config) + + +class TestLandcoverAISegmentationTask: + @pytest.fixture(scope="class") + def datamodule(self) -> LandcoverAIDataModule: + root = os.path.join("tests", "data", "landcoverai") + batch_size = 2 + num_workers = 0 + dm = LandcoverAIDataModule(root, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture( + params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "jaccard", "focal"]) + ) + def config(self, request: SubRequest) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "landcoverai.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + segmentation_model, loss = request.param + task_args["segmentation_model"] = segmentation_model + task_args["loss"] = loss + task_args["verbose"] = True + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> LandcoverAISegmentationTask: + task = LandcoverAISegmentationTask(**config) + trainer = FakeTrainer() + monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_training( + self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test( + self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask + ) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_configure_optimizers(self, task: LandcoverAISegmentationTask) -> None: + out = task.configure_optimizers() + assert "optimizer" in out + assert "lr_scheduler" in out + + def test_invalid_model(self, config: Dict[str, Any]) -> None: + config["segmentation_model"] = "invalid_model" + error_message = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=error_message): + LandcoverAISegmentationTask(**config) + + def test_invalid_loss(self, config: Dict[str, Any]) -> None: + config["loss"] = "invalid_loss" + error_message = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=error_message): + LandcoverAISegmentationTask(**config) + + +class TestNAIPChesapeakeSegmentationTask: + @pytest.fixture(scope="class") + def datamodule(self) -> NAIPChesapeakeDataModule: + dm = NAIPChesapeakeDataModule( + os.path.join("tests", "data", "naip"), + os.path.join("tests", "data", "chesapeake", "BAYWIDE"), + batch_size=2, + num_workers=0, + ) + dm.patch_size = 32 + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture(params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "ce", "jaccard"])) + def config(self, request: SubRequest) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + segmentation_model, loss = request.param + task_args["segmentation_model"] = segmentation_model + task_args["loss"] = loss + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> NAIPChesapeakeSegmentationTask: + task = NAIPChesapeakeSegmentationTask(**config) + trainer = FakeTrainer() + monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_configure_optimizers(self, task: NAIPChesapeakeSegmentationTask) -> None: + out = task.configure_optimizers() + assert "optimizer" in out + assert "lr_scheduler" in out + + def test_training( + self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test( + self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask + ) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_invalid_model(self, config: Dict[str, Any]) -> None: + config["segmentation_model"] = "invalid_model" + error_message = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=error_message): + NAIPChesapeakeSegmentationTask(**config) + + def test_invalid_loss(self, config: Dict[str, Any]) -> None: + config["loss"] = "invalid_loss" + error_message = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=error_message): + NAIPChesapeakeSegmentationTask(**config) + + +class TestSEN12MSSegmentationTask: + @pytest.fixture( + scope="class", + params=[("all", 15), ("s1", 2), ("s2-all", 13), ("s2-reduced", 6)], + ) + def bands(self, request: SubRequest) -> Tuple[str, int]: + return cast(Tuple[str, int], request.param) + + @pytest.fixture(scope="class") + def datamodule(self, bands: Tuple[str, int]) -> SEN12MSDataModule: + root = os.path.join("tests", "data", "sen12ms") + seed = 0 + band_set = bands[0] + batch_size = 1 + num_workers = 0 + dm = SEN12MSDataModule(root, seed, band_set, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture(params=["ce", "jaccard"]) + def config(self, bands: Tuple[str, int], request: SubRequest) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "sen12ms.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + task_args["in_channels"] = bands[1] + task_args["loss"] = request.param + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> SEN12MSSegmentationTask: + task = SEN12MSSegmentationTask(**config) + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_configure_optimizers(self, task: SEN12MSSegmentationTask) -> None: + out = task.configure_optimizers() + assert "optimizer" in out + assert "lr_scheduler" in out + + def test_training( + self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test( + self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask + ) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_invalid_model(self, config: Dict[str, Any]) -> None: + config["segmentation_model"] = "invalid_model" + error_message = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=error_message): + SEN12MSSegmentationTask(**config) + + def test_invalid_loss(self, config: Dict[str, Any]) -> None: + config["loss"] = "invalid_loss" + error_message = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=error_message): + SEN12MSSegmentationTask(**config) diff --git a/tests/trainers/test_sen12ms.py b/tests/trainers/test_sen12ms.py deleted file mode 100644 index 30e5a9bc972..00000000000 --- a/tests/trainers/test_sen12ms.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator, Tuple, cast - -import pytest -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch -from omegaconf import OmegaConf - -from torchgeo.datasets import SEN12MSDataModule -from torchgeo.trainers import SEN12MSSegmentationTask - -from .test_utils import mocked_log - - -@pytest.fixture( - scope="module", params=[("all", 15), ("s1", 2), ("s2-all", 13), ("s2-reduced", 6)] -) -def bands(request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - -@pytest.fixture(scope="module") -def datamodule(bands: Tuple[str, int]) -> SEN12MSDataModule: - root = os.path.join("tests", "data", "sen12ms") - seed = 0 - band_set = bands[0] - batch_size = 1 - num_workers = 0 - dm = SEN12MSDataModule(root, seed, band_set, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - -class TestSEN12MSSegmentationTask: - @pytest.fixture(params=["ce", "jaccard"]) - def config(self, bands: Tuple[str, int], request: SubRequest) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "sen12ms.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] - task_args["loss"] = request.param - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> SEN12MSSegmentationTask: - task = SEN12MSSegmentationTask(**config) - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_configure_optimizers(self, task: SEN12MSSegmentationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_training( - self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test( - self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["segmentation_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - SEN12MSSegmentationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - SEN12MSSegmentationTask(**config) diff --git a/tests/trainers/test_so2sat.py b/tests/trainers/test_so2sat.py deleted file mode 100644 index f828235a8c5..00000000000 --- a/tests/trainers/test_so2sat.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator, Tuple, cast - -import pytest -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch -from omegaconf import OmegaConf - -from torchgeo.datasets import So2SatDataModule -from torchgeo.trainers import So2SatClassificationTask - -from .test_utils import mocked_log - - -@pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)]) -def bands(request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - -@pytest.fixture(scope="module", params=[True, False]) -def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule: - band_set = bands[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "so2sat") - batch_size = 2 - num_workers = 0 - dm = So2SatDataModule(root, batch_size, num_workers, band_set, unsupervised_mode) - dm.prepare_data() - dm.setup() - return dm - - -class TestSo2SatClassificationTask: - @pytest.fixture( - params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) - ) - def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] - loss, weights = request.param - task_args["loss"] = loss - task_args["weights"] = weights - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> So2SatClassificationTask: - task = So2SatClassificationTask(**config) - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_configure_optimizers(self, task: So2SatClassificationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_training( - self, datamodule: So2SatDataModule, task: So2SatClassificationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: So2SatDataModule, task: So2SatClassificationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test( - self, datamodule: So2SatDataModule, task: So2SatClassificationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_pretrained(self, checkpoint: str) -> None: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["weights"] = checkpoint - with pytest.warns(UserWarning): - So2SatClassificationTask(**task_args) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["classification_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - So2SatClassificationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - So2SatClassificationTask(**config) - - def test_invalid_weights(self, config: Dict[str, Any]) -> None: - config["weights"] = "invalid_weights" - error_message = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=error_message): - So2SatClassificationTask(**config) - - def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> None: - config["weights"] = checkpoint - config["classification_model"] = "resnet50" - error_message = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=error_message): - So2SatClassificationTask(**config) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index 1234e0b3019..ad398a1e7cd 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -4,12 +4,18 @@ """TorchGeo trainers.""" from .byol import BYOLTask -from .chesapeake import ChesapeakeCVPRSegmentationTask -from .landcoverai import LandcoverAISegmentationTask -from .naipchesapeake import NAIPChesapeakeSegmentationTask -from .sen12ms import SEN12MSSegmentationTask -from .so2sat import So2SatClassificationTask -from .tasks import ClassificationTask, MultiLabelClassificationTask, RegressionTask +from .classification import ( + ClassificationTask, + MultiLabelClassificationTask, + So2SatClassificationTask, +) +from .regression import RegressionTask +from .segmentation import ( + ChesapeakeCVPRSegmentationTask, + LandcoverAISegmentationTask, + NAIPChesapeakeSegmentationTask, + SEN12MSSegmentationTask, +) __all__ = ( "BYOLTask", diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index fc9601bce22..23690cd298e 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""Trainer task for BYOL.""" +"""BYOL tasks.""" + import random from typing import Any, Callable, Dict, Optional, Tuple, Union, cast @@ -18,12 +19,9 @@ from torchvision.models import resnet18 from torchvision.models.resnet import resnet50 +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 Module.__module__ = "torch.nn" -Sequential.__module__ = "torch.nn" -Linear.__module__ = "torch.nn" -ReLU.__module__ = "torch.nn" -BatchNorm1d.__module__ = "torch.nn" -Conv2d.__module__ = "torch.nn" def normalized_mse(x: Tensor, y: Tensor) -> Tensor: diff --git a/torchgeo/trainers/chesapeake.py b/torchgeo/trainers/chesapeake.py deleted file mode 100644 index 47c0e11b902..00000000000 --- a/torchgeo/trainers/chesapeake.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Trainers for the Chesapeake datasets.""" - -from typing import Any, Dict, cast - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import segmentation_models_pytorch as smp -import torch -import torch.nn as nn -from pytorch_lightning.core.lightning import LightningModule -from torch import Tensor -from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] -from torchmetrics import Accuracy, IoU, MetricCollection - -from ..datasets import Chesapeake7 -from ..models import FCN - -# TODO: move the color maps to a dataset object -CMAP_7 = matplotlib.colors.ListedColormap( - [np.array(Chesapeake7.cmap[i]) / 255.0 for i in range(7)] -) -CMAP_5 = matplotlib.colors.ListedColormap( - np.array( - [ - (0, 0, 0, 0), - (0, 197, 255, 255), - (38, 115, 0, 255), - (163, 255, 115, 255), - (156, 156, 156, 255), - ] - ) - / 255.0 -) - - -class ChesapeakeCVPRSegmentationTask(LightningModule): - """LightningModule for training models on the Chesapeake CVPR Land Cover dataset. - - This allows using arbitrary models and losses from the - ``pytorch_segmentation_models`` package. - """ - - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" - if self.hparams["class_set"] not in [5, 7]: - raise ValueError("'class_set' must be either 5 or 7") - num_classes = self.hparams["class_set"] - classes = range(1, self.hparams["class_set"]) - - if self.hparams["segmentation_model"] == "unet": - self.model = smp.Unet( - encoder_name=self.hparams["encoder_name"], - encoder_weights=self.hparams["encoder_weights"], - in_channels=4, - classes=num_classes, - ) - elif self.hparams["segmentation_model"] == "deeplabv3+": - self.model = smp.DeepLabV3Plus( - encoder_name=self.hparams["encoder_name"], - encoder_weights=self.hparams["encoder_weights"], - in_channels=4, - classes=num_classes, - ) - elif self.hparams["segmentation_model"] == "fcn": - self.model = FCN(in_channels=4, classes=num_classes, num_filters=256) - else: - raise ValueError( - f"Model type '{self.hparams['segmentation_model']}' is not valid." - ) - - if self.hparams["loss"] == "ce": - self.loss = nn.CrossEntropyLoss( # type: ignore[attr-defined] - ignore_index=0 - ) - elif self.hparams["loss"] == "jaccard": - self.loss = smp.losses.JaccardLoss(mode="multiclass", classes=classes) - elif self.hparams["loss"] == "focal": - self.loss = smp.losses.FocalLoss( - "multiclass", ignore_index=0, normalized=True - ) - else: - raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - segmentation_model: Name of the segmentation model type to use - encoder_name: Name of the encoder model backbone to use - encoder_weights: None or "imagenet" to use imagenet pretrained weights in - the encoder model - loss: Name of the loss function - - Raises: - ValueError: if kwargs arguments are invalid - """ - super().__init__() - self.save_hyperparameters() # creates `self.hparams` from kwargs - - self.config_task() - - self.train_metrics = MetricCollection( - [ - Accuracy(num_classes=self.hparams["class_set"], ignore_index=0), - IoU(num_classes=self.hparams["class_set"], ignore_index=0), - ], - prefix="train_", - ) - self.val_metrics = self.train_metrics.clone(prefix="val_") - self.test_metrics = self.train_metrics.clone(prefix="test_") - - def forward(self, x: Tensor) -> Any: # type: ignore[override] - """Forward pass of the model. - - Args: - x: tensor of data to run through the model - - Returns: - output from the model - """ - return self.model(x) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step - reports average accuracy and average IoU. - - Args: - batch: Current batch - batch_idx: Index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_metrics(y_hat_hard, y) - - return cast(Tensor, loss) - - def training_epoch_end(self, outputs: Any) -> None: - """Logs epoch level training metrics. - - Args: - outputs: list of items returned by training_step - """ - self.log_dict(self.train_metrics.compute()) - self.train_metrics.reset() - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Logs the first 10 validation samples to tensorboard as images with 3 subplots - showing the image, mask, and predictions. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat_hard, y) - - if batch_idx < 10: - cmap = None - if self.hparams["class_set"] == 5: - cmap = CMAP_5 - else: - cmap = CMAP_7 - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - batch["image"][0].cpu().numpy(), 0, 3 - ) - mask = batch["mask"][0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img[:, :, :3]) - axs[0].axis("off") - axs[1].imshow( - mask, - vmin=0, - vmax=self.hparams["class_set"] - 1, - cmap=cmap, - interpolation="none", - ) - axs[1].axis("off") - axs[2].imshow( - pred, - vmin=0, - vmax=self.hparams["class_set"] - 1, - cmap=cmap, - interpolation="none", - ) - axs[2].axis("off") - plt.tight_layout() - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - - def validation_epoch_end(self, outputs: Any) -> None: - """Logs epoch level validation metrics. - - Args: - outputs: list of items returned by validation_step - """ - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() - - def test_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Test step identical to the validation step. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss, on_step=False, on_epoch=True) - self.test_metrics(y_hat_hard, y) - - def test_epoch_end(self, outputs: Any) -> None: - """Logs epoch level test metrics. - - Args: - outputs: list of items returned by test_step - """ - self.log_dict(self.test_metrics.compute()) - self.test_metrics.reset() - - def configure_optimizers(self) -> Dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - a "lr dict" according to the pytorch lightning documentation -- - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers - """ - optimizer = torch.optim.Adam( - self.model.parameters(), lr=self.hparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, patience=self.hparams["learning_rate_schedule_patience"] - ), - "monitor": "val_loss", - }, - } diff --git a/torchgeo/trainers/tasks.py b/torchgeo/trainers/classification.py similarity index 71% rename from torchgeo/trainers/tasks.py rename to torchgeo/trainers/classification.py index d45cd1d4070..e8a00459ee4 100644 --- a/torchgeo/trainers/tasks.py +++ b/torchgeo/trainers/classification.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""Base trainer tasks.""" +"""Classification tasks.""" import os from typing import Any, Dict, cast @@ -10,20 +10,12 @@ import timm import torch import torch.nn as nn -import torch.nn.functional as F +import torchvision.models from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss from torch import Tensor from torch.nn.modules import Conv2d, Linear from torch.optim.lr_scheduler import ReduceLROnPlateau -from torchmetrics import ( - Accuracy, - FBeta, - IoU, - MeanAbsoluteError, - MeanSquaredError, - MetricCollection, -) -from torchvision import models +from torchmetrics import Accuracy, FBeta, IoU, MetricCollection from . import utils @@ -364,143 +356,73 @@ def test_step( # type: ignore[override] self.test_metrics(y_hat_hard, y) -class RegressionTask(pl.LightningModule): - """LightningModule for training models on regression datasets.""" +# TODO: move this functionality into ClassificationTask and remove this class +class So2SatClassificationTask(ClassificationTask): + """LightningModule for training models on the So2Sat Dataset.""" - def config_task(self) -> None: - """Configures the task based on kwargs parameters.""" - if self.hparams["model"] == "resnet18": - self.model = models.resnet18(pretrained=True) - in_features = self.model.fc.in_features - self.model.fc = nn.Linear( # type: ignore[attr-defined] - in_features, out_features=1 - ) - else: - raise ValueError(f"Model type '{self.hparams['model']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize a new LightningModule for training simple regression models. - - Keyword Args: - model: Name of the model to use - learning_rate: Initial learning rate to use in the optimizer - learning_rate_schedule_patience: Patience parameter for the LR scheduler - """ - super().__init__() - self.save_hyperparameters() # creates `self.hparams` from kwargs - self.config_task() - - self.train_metrics = MetricCollection( - {"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()}, - prefix="train_", - ) - self.val_metrics = self.train_metrics.clone(prefix="val_") - self.test_metrics = self.train_metrics.clone(prefix="test_") - - def forward(self, x: Tensor) -> Any: # type: ignore[override] - """Forward pass of the model.""" - return self.model(x) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step with an MSE loss. - - Args: - batch: Current batch - batch_idx: Index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["label"].view(-1, 1) - y_hat = self.forward(x) - - loss = F.mse_loss(y_hat, y) - - self.log("train_loss", loss) # logging to TensorBoard - self.train_metrics(y_hat, y) - - return loss - - def training_epoch_end(self, outputs: Any) -> None: - """Logs epoch-level training metrics. - - Args: - outputs: list of items returned by training_step - """ - self.log_dict(self.train_metrics.compute()) - self.train_metrics.reset() - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["label"].view(-1, 1) - y_hat = self.forward(x) - - loss = F.mse_loss(y_hat, y) - self.log("val_loss", loss) - self.val_metrics(y_hat, y) - - def validation_epoch_end(self, outputs: Any) -> None: - """Logs epoch level validation metrics. - - Args: - outputs: list of items returned by validation_step - """ - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() + def config_model(self) -> None: + """Configures the model based on kwargs parameters passed to the constructor.""" + in_channels = self.hparams["in_channels"] - def test_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Test step. + pretrained = False + if not os.path.exists(self.hparams["weights"]): + if self.hparams["weights"] == "imagenet": + pretrained = True + elif self.hparams["weights"] == "random": + pretrained = False + else: + raise ValueError( + f"Weight type '{self.hparams['weights']}' is not valid." + ) - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["label"].view(-1, 1) - y_hat = self.forward(x) + # Create the model + if "resnet" in self.hparams["classification_model"]: + self.model = getattr( + torchvision.models.resnet, self.hparams["classification_model"] + )(pretrained=pretrained) + in_features = self.model.fc.in_features + self.model.fc = Linear( + in_features, out_features=self.hparams["num_classes"] + ) - loss = F.mse_loss(y_hat, y) - self.log("test_loss", loss) - self.test_metrics(y_hat, y) + # Update first layer + if in_channels != 3: + w_old = None + if pretrained: + w_old = torch.clone( # type: ignore[attr-defined] + self.model.conv1.weight + ).detach() + # Create the new layer + self.model.conv1 = Conv2d( + in_channels, 64, kernel_size=7, stride=1, padding=2, bias=False + ) + nn.init.kaiming_normal_( # type: ignore[no-untyped-call] + self.model.conv1.weight, mode="fan_out", nonlinearity="relu" + ) - def test_epoch_end(self, outputs: Any) -> None: - """Logs epoch level test metrics. + # We copy over the pretrained RGB weights + if pretrained: + w_new = torch.clone( # type: ignore[attr-defined] + self.model.conv1.weight + ).detach() + w_new[:, :3, :, :] = w_old + self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 + w_new + ) + else: + raise ValueError( + f"Model type '{self.hparams['classification_model']}' is not valid." + ) - Args: - outputs: list of items returned by test_step - """ - self.log_dict(self.test_metrics.compute()) - self.test_metrics.reset() + # Load pretrained weights checkpoint weights + if "resnet" in self.hparams["classification_model"]: + if os.path.exists(self.hparams["weights"]): + name, state_dict = utils.extract_encoder(self.hparams["weights"]) - def configure_optimizers(self) -> Dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. + if self.hparams["classification_model"] != name: + raise ValueError( + f"Trying to load {name} weights into a " + f"{self.hparams['classification_model']}" + ) - Returns: - a "lr dict" according to the pytorch lightning documentation -- - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers - """ - optimizer = torch.optim.AdamW( - self.model.parameters(), lr=self.hparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, patience=self.hparams["learning_rate_schedule_patience"] - ), - "monitor": "val_loss", - }, - } + self.model = utils.load_state_dict(self.model, state_dict) diff --git a/torchgeo/trainers/landcoverai.py b/torchgeo/trainers/landcoverai.py deleted file mode 100644 index 245939b92e9..00000000000 --- a/torchgeo/trainers/landcoverai.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Landcover.ai trainer.""" - -from typing import Any, Dict, cast - -import kornia.augmentation as K -import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -import segmentation_models_pytorch as smp -import torch -import torch.nn as nn -from torch import Tensor -from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] -from torchmetrics import Accuracy, IoU, MetricCollection - -from ..models import FCN - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - - -class LandcoverAISegmentationTask(pl.LightningModule): - """LightningModule for training models on the Landcover.AI Dataset. - - This allows using arbitrary models and losses from the - ``pytorch_segmentation_models`` package. - """ - - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" - if self.hparams["segmentation_model"] == "unet": - self.model = smp.Unet( - encoder_name=self.hparams["encoder_name"], - encoder_weights=self.hparams["encoder_weights"], - in_channels=3, - classes=6, - ) - elif self.hparams["segmentation_model"] == "deeplabv3+": - self.model = smp.DeepLabV3Plus( - encoder_name=self.hparams["encoder_name"], - encoder_weights=self.hparams["encoder_weights"], - in_channels=3, - classes=6, - ) - elif self.hparams["segmentation_model"] == "fcn": - self.model = FCN(in_channels=3, classes=6, num_filters=256) - else: - raise ValueError( - f"Model type '{self.hparams['segmentation_model']}' is not valid." - ) - - if self.hparams["loss"] == "ce": - self.loss = nn.CrossEntropyLoss( # type: ignore[attr-defined] - ignore_index=0 - ) - elif self.hparams["loss"] == "jaccard": - self.loss = smp.losses.JaccardLoss(mode="multiclass", classes=range(1, 6)) - elif self.hparams["loss"] == "focal": - self.loss = smp.losses.FocalLoss( - "multiclass", ignore_index=0, normalized=True - ) - else: - raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - segmentation_model: Name of the segmentation model type to use - encoder_name: Name of the encoder model backbone to use - encoder_weights: None or "imagenet" to use imagenet pretrained weights in - the encoder model - encoder_output_stride: The output stride parameter in DeepLabV3+ models - loss: Name of the loss function - """ - super().__init__() - self.save_hyperparameters() # creates `self.hparams` from kwargs - - self.config_task() - - self.train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["input", "mask"], - ) - - self.train_metrics = MetricCollection( - [ - Accuracy(num_classes=6, ignore_index=0), - IoU(num_classes=6, ignore_index=0), - ], - prefix="train_", - ) - self.val_metrics = self.train_metrics.clone(prefix="val_") - self.test_metrics = self.train_metrics.clone(prefix="test_") - - def forward(self, x: Tensor) -> Any: # type: ignore[override] - """Forward pass of the model. - - Args: - x: input image - - Returns: - prediction - """ - return self.model(x) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step - reports average accuracy and average IoU. - - Args: - batch: Current batch - batch_idx: Index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["mask"] - with torch.no_grad(): - x, y = self.train_augmentations(x, y) - y = y.long().squeeze() - - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_metrics(y_hat_hard, y) - - return cast(Tensor, loss) - - def training_epoch_end(self, outputs: Any) -> None: - """Logs epoch level training metrics. - - Args: - outputs: list of items returned by training_step - """ - self.log_dict(self.train_metrics.compute()) - self.train_metrics.reset() - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Logs the first 10 validation samples to tensorboard as images with 3 subplots - showing the image, mask, and predictions. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"].long().squeeze() - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat_hard, y) - - if batch_idx < 10 and self.hparams["verbose"]: - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - x[0].cpu().numpy(), 0, 3 - ) - mask = y[0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img) - axs[0].axis("off") - axs[1].imshow(mask, vmin=0, vmax=5) - axs[1].axis("off") - axs[2].imshow(pred, vmin=0, vmax=5) - axs[2].axis("off") - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - - plt.close() - - def validation_epoch_end(self, outputs: Any) -> None: - """Logs epoch level validation metrics. - - Args: - outputs: list of items returned by validation_step - """ - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() - - def test_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Test step identical to the validation step. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"].long().squeeze() - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss, on_step=False, on_epoch=True) - self.test_metrics(y_hat_hard, y) - - def test_epoch_end(self, outputs: Any) -> None: - """Logs epoch level test metrics. - - Args: - outputs: list of items returned by test_step - """ - self.log_dict(self.test_metrics.compute()) - self.test_metrics.reset() - - def configure_optimizers(self) -> Dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - a "lr dict" according to the pytorch lightning documentation -- - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers - """ - optimizer = torch.optim.Adam( - self.model.parameters(), lr=self.hparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, patience=self.hparams["learning_rate_schedule_patience"] - ), - "monitor": "val_loss", - }, - } diff --git a/torchgeo/trainers/naipchesapeake.py b/torchgeo/trainers/naipchesapeake.py deleted file mode 100644 index f0c2f9e585c..00000000000 --- a/torchgeo/trainers/naipchesapeake.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""NAIP + Chesapeake trainer.""" - -from typing import Any, Dict, cast - -import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -import segmentation_models_pytorch as smp -import torch -import torch.nn as nn -from torch import Tensor -from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] -from torchmetrics import Accuracy, IoU - -from ..models import FCN - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - - -class NAIPChesapeakeSegmentationTask(pl.LightningModule): - """LightningModule for training models on the NAIP and Chesapeake datasets. - - This allows using arbitrary models and losses from the - ``pytorch_segmentation_models`` package. - """ - - in_channels = 4 - classes = 13 - # TODO: tune this hyperparam - num_filters = 64 - - def config_task(self, kwargs: Any) -> None: - """Configures the task based on kwargs parameters.""" - if kwargs["segmentation_model"] == "unet": - self.model = smp.Unet( - encoder_name=kwargs["encoder_name"], - encoder_weights=kwargs["encoder_weights"], - in_channels=self.in_channels, - classes=self.classes, - ) - elif kwargs["segmentation_model"] == "deeplabv3+": - self.model = smp.DeepLabV3Plus( - encoder_name=kwargs["encoder_name"], - encoder_weights=kwargs["encoder_weights"], - encoder_output_stride=kwargs["encoder_output_stride"], - in_channels=self.in_channels, - classes=self.classes, - ) - elif kwargs["segmentation_model"] == "fcn": - self.model = FCN(self.in_channels, self.classes, self.num_filters) - else: - raise ValueError( - f"Model type '{kwargs['segmentation_model']}' is not valid." - ) - - if kwargs["loss"] == "ce": - self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined] - elif kwargs["loss"] == "jaccard": - self.loss = smp.losses.JaccardLoss(mode="multiclass") - else: - raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - segmentation_model: Name of the segmentation model type to use - encoder_name: Name of the encoder model backbone to use - encoder_weights: None or "imagenet" to use imagenet pretrained weights in - the encoder model - encoder_output_stride: The output stride parameter in DeepLabV3+ models - loss: Name of the loss function - """ - super().__init__() - self.save_hyperparameters() # creates `self.hparams` from kwargs - - self.config_task(kwargs) - - self.train_accuracy = Accuracy() - self.val_accuracy = Accuracy() - self.test_accuracy = Accuracy() - - self.train_iou = IoU(num_classes=self.classes) - self.val_iou = IoU(num_classes=self.classes) - self.test_iou = IoU(num_classes=self.classes) - - def forward(self, x: Tensor) -> Any: # type: ignore[override] - """Forward pass of the model. - - Args: - x: input image - - Returns: - prediction - """ - return self.model(x) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step - reports average accuracy and average IoU. - - Args: - batch: current batch - batch_idx: index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_accuracy(y_hat_hard, y) - self.train_iou(y_hat_hard, y) - - return cast(Tensor, loss) - - def training_epoch_end(self, outputs: Any) -> None: - """Logs epoch level training metrics. - - Args: - outputs: list of items returned by training_step - """ - self.log("train_acc", self.train_accuracy.compute()) - self.log("train_iou", self.train_iou.compute()) - self.train_accuracy.reset() - self.train_iou.reset() - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Args: - batch: current batch - batch_idx: index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("val_loss", loss) - self.val_accuracy(y_hat_hard, y) - self.val_iou(y_hat_hard, y) - - if batch_idx < 10: - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - batch["image"][0].cpu().numpy(), 0, 3 - ) - mask = batch["mask"][0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img) - axs[0].axis("off") - axs[1].imshow(mask, vmin=0, vmax=4) - axs[1].axis("off") - axs[2].imshow(pred, vmin=0, vmax=4) - axs[2].axis("off") - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - - plt.close() - - def validation_epoch_end(self, outputs: Any) -> None: - """Logs epoch level validation metrics. - - Args: - outputs: list of items returned by validation_step - """ - self.log("val_acc", self.val_accuracy.compute()) - self.log("val_iou", self.val_iou.compute()) - self.val_accuracy.reset() - self.val_iou.reset() - - def test_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Test step identical to the validation step. - - Args: - batch: current batch - batch_idx: index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss) - self.test_accuracy(y_hat_hard, y) - self.test_iou(y_hat_hard, y) - - def test_epoch_end(self, outputs: Any) -> None: - """Logs epoch level test metrics. - - Args: - outputs: list of items returned by test_step - """ - self.log("test_acc", self.test_accuracy.compute()) - self.log("test_iou", self.test_iou.compute()) - self.test_accuracy.reset() - self.test_iou.reset() - - def configure_optimizers(self) -> Dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - a "lr dict" according to the pytorch lightning documentation -- - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers - """ - optimizer = torch.optim.AdamW( - self.model.parameters(), lr=self.hparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, patience=self.hparams["learning_rate_schedule_patience"] - ), - "monitor": "val_loss", - "verbose": True, - }, - } diff --git a/torchgeo/trainers/sen12ms.py b/torchgeo/trainers/regression.py similarity index 56% rename from torchgeo/trainers/sen12ms.py rename to torchgeo/trainers/regression.py index 9af75eb0456..47d138ae228 100644 --- a/torchgeo/trainers/sen12ms.py +++ b/torchgeo/trainers/regression.py @@ -1,81 +1,67 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""SEN12MS trainer.""" +"""Regression tasks.""" -from typing import Any, Dict, cast +from typing import Any, Dict import pytorch_lightning as pl -import segmentation_models_pytorch as smp import torch import torch.nn as nn +import torch.nn.functional as F from torch import Tensor +from torch.nn.modules import Conv2d, Linear from torch.optim.lr_scheduler import ReduceLROnPlateau -from torchmetrics import Accuracy, MetricCollection +from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection +from torchvision import models +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +Conv2d.__module__ = "nn.Conv2d" +Linear.__module__ = "nn.Linear" -class SEN12MSSegmentationTask(pl.LightningModule): - """LightningModule for training models on the SEN12MS Dataset. - This allows using arbitrary models and losses from the - ``pytorch_segmentation_models`` package. - """ +class RegressionTask(pl.LightningModule): + """LightningModule for training models on regression datasets.""" def config_task(self) -> None: """Configures the task based on kwargs parameters.""" - if self.hparams["segmentation_model"] == "unet": - self.model = smp.Unet( - encoder_name=self.hparams["encoder_name"], - encoder_weights=self.hparams["encoder_weights"], - in_channels=self.hparams["in_channels"], - classes=11, + if self.hparams["model"] == "resnet18": + self.model = models.resnet18(pretrained=True) + in_features = self.model.fc.in_features + self.model.fc = nn.Linear( # type: ignore[attr-defined] + in_features, out_features=1 ) else: - raise ValueError( - f"Model type '{self.hparams['segmentation_model']}' is not valid." - ) - - if self.hparams["loss"] == "ce": - self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined] - elif self.hparams["loss"] == "jaccard": - self.loss = smp.losses.JaccardLoss(mode="multiclass") - else: - raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") + raise ValueError(f"Model type '{self.hparams['model']}' is not valid.") def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. + """Initialize a new LightningModule for training simple regression models. Keyword Args: - segmentation_model: Name of the segmentation model type to use - encoder_name: Name of the encoder model backbone to use - encoder_weights: None or "imagenet" to use imagenet pretrained weights in - the encoder model - loss: Name of the loss function + model: Name of the model to use + learning_rate: Initial learning rate to use in the optimizer + learning_rate_schedule_patience: Patience parameter for the LR scheduler """ super().__init__() self.save_hyperparameters() # creates `self.hparams` from kwargs - self.config_task() - self.train_metrics = MetricCollection([Accuracy()], prefix="train_") + self.train_metrics = MetricCollection( + {"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()}, + prefix="train_", + ) self.val_metrics = self.train_metrics.clone(prefix="val_") self.test_metrics = self.train_metrics.clone(prefix="test_") def forward(self, x: Tensor) -> Any: # type: ignore[override] - """Forward pass of the model. - - Args: - x: input image - - Returns: - prediction - """ + """Forward pass of the model.""" return self.model(x) def training_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> Tensor: - """Training step. + """Training step with an MSE loss. Args: batch: Current batch @@ -85,19 +71,18 @@ def training_step( # type: ignore[override] training loss """ x = batch["image"] - y = batch["mask"] + y = batch["label"].view(-1, 1) y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - loss = self.loss(y_hat, y) + loss = F.mse_loss(y_hat, y) - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_metrics(y_hat_hard, y) + self.log("train_loss", loss) # logging to TensorBoard + self.train_metrics(y_hat, y) - return cast(Tensor, loss) + return loss def training_epoch_end(self, outputs: Any) -> None: - """Logs epoch level training metrics. + """Logs epoch-level training metrics. Args: outputs: list of items returned by training_step @@ -115,14 +100,12 @@ def validation_step( # type: ignore[override] batch_idx: Index of current batch """ x = batch["image"] - y = batch["mask"] + y = batch["label"].view(-1, 1) y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat_hard, y) + loss = F.mse_loss(y_hat, y) + self.log("val_loss", loss) + self.val_metrics(y_hat, y) def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. @@ -143,15 +126,12 @@ def test_step( # type: ignore[override] batch_idx: Index of current batch """ x = batch["image"] - y = batch["mask"] + y = batch["label"].view(-1, 1) y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss, on_step=False, on_epoch=True) - self.test_metrics(y_hat_hard, y) + loss = F.mse_loss(y_hat, y) + self.log("test_loss", loss) + self.test_metrics(y_hat, y) def test_epoch_end(self, outputs: Any) -> None: """Logs epoch level test metrics. diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py new file mode 100644 index 00000000000..4637950b381 --- /dev/null +++ b/torchgeo/trainers/segmentation.py @@ -0,0 +1,921 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Segmentation tasks.""" + +from typing import Any, Dict, cast + +import kornia.augmentation as K +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +import segmentation_models_pytorch as smp +import torch +import torch.nn as nn +from pytorch_lightning.core.lightning import LightningModule +from torch import Tensor +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] +from torchmetrics import Accuracy, IoU, MetricCollection + +from ..datasets import Chesapeake7 +from ..models import FCN + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + +# TODO: move the color maps to a dataset object +CMAP_7 = matplotlib.colors.ListedColormap( + [np.array(Chesapeake7.cmap[i]) / 255.0 for i in range(7)] +) +CMAP_5 = matplotlib.colors.ListedColormap( + np.array( + [ + (0, 0, 0, 0), + (0, 197, 255, 255), + (38, 115, 0, 255), + (163, 255, 115, 255), + (156, 156, 156, 255), + ] + ) + / 255.0 +) + + +# TODO: combine all of these classes into a single SemanticSegmentationTask +class ChesapeakeCVPRSegmentationTask(LightningModule): + """LightningModule for training models on the Chesapeake CVPR Land Cover dataset. + + This allows using arbitrary models and losses from the + ``pytorch_segmentation_models`` package. + """ + + def config_task(self) -> None: + """Configures the task based on kwargs parameters passed to the constructor.""" + if self.hparams["class_set"] not in [5, 7]: + raise ValueError("'class_set' must be either 5 or 7") + num_classes = self.hparams["class_set"] + classes = range(1, self.hparams["class_set"]) + + if self.hparams["segmentation_model"] == "unet": + self.model = smp.Unet( + encoder_name=self.hparams["encoder_name"], + encoder_weights=self.hparams["encoder_weights"], + in_channels=4, + classes=num_classes, + ) + elif self.hparams["segmentation_model"] == "deeplabv3+": + self.model = smp.DeepLabV3Plus( + encoder_name=self.hparams["encoder_name"], + encoder_weights=self.hparams["encoder_weights"], + in_channels=4, + classes=num_classes, + ) + elif self.hparams["segmentation_model"] == "fcn": + self.model = FCN(in_channels=4, classes=num_classes, num_filters=256) + else: + raise ValueError( + f"Model type '{self.hparams['segmentation_model']}' is not valid." + ) + + if self.hparams["loss"] == "ce": + self.loss = nn.CrossEntropyLoss( # type: ignore[attr-defined] + ignore_index=0 + ) + elif self.hparams["loss"] == "jaccard": + self.loss = smp.losses.JaccardLoss(mode="multiclass", classes=classes) + elif self.hparams["loss"] == "focal": + self.loss = smp.losses.FocalLoss( + "multiclass", ignore_index=0, normalized=True + ) + else: + raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LightningModule with a model and loss function. + + Keyword Args: + segmentation_model: Name of the segmentation model type to use + encoder_name: Name of the encoder model backbone to use + encoder_weights: None or "imagenet" to use imagenet pretrained weights in + the encoder model + loss: Name of the loss function + + Raises: + ValueError: if kwargs arguments are invalid + """ + super().__init__() + self.save_hyperparameters() # creates `self.hparams` from kwargs + + self.config_task() + + self.train_metrics = MetricCollection( + [ + Accuracy(num_classes=self.hparams["class_set"], ignore_index=0), + IoU(num_classes=self.hparams["class_set"], ignore_index=0), + ], + prefix="train_", + ) + self.val_metrics = self.train_metrics.clone(prefix="val_") + self.test_metrics = self.train_metrics.clone(prefix="test_") + + def forward(self, x: Tensor) -> Any: # type: ignore[override] + """Forward pass of the model. + + Args: + x: tensor of data to run through the model + + Returns: + output from the model + """ + return self.model(x) + + def training_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> Tensor: + """Training step - reports average accuracy and average IoU. + + Args: + batch: Current batch + batch_idx: Index of current batch + + Returns: + training loss + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the train step logs every `log_every_n_steps` steps where + # `log_every_n_steps` is a parameter to the `Trainer` object + self.log("train_loss", loss, on_step=True, on_epoch=False) + self.train_metrics(y_hat_hard, y) + + return cast(Tensor, loss) + + def training_epoch_end(self, outputs: Any) -> None: + """Logs epoch level training metrics. + + Args: + outputs: list of items returned by training_step + """ + self.log_dict(self.train_metrics.compute()) + self.train_metrics.reset() + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step - reports average accuracy and average IoU. + + Logs the first 10 validation samples to tensorboard as images with 3 subplots + showing the image, mask, and predictions. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat_hard, y) + + if batch_idx < 10: + cmap = None + if self.hparams["class_set"] == 5: + cmap = CMAP_5 + else: + cmap = CMAP_7 + # Render the image, ground truth mask, and predicted mask for the first + # image in the batch + img = np.rollaxis( # convert image to channels last format + batch["image"][0].cpu().numpy(), 0, 3 + ) + mask = batch["mask"][0].cpu().numpy() + pred = y_hat_hard[0].cpu().numpy() + fig, axs = plt.subplots(1, 3, figsize=(12, 4)) + axs[0].imshow(img[:, :, :3]) + axs[0].axis("off") + axs[1].imshow( + mask, + vmin=0, + vmax=self.hparams["class_set"] - 1, + cmap=cmap, + interpolation="none", + ) + axs[1].axis("off") + axs[2].imshow( + pred, + vmin=0, + vmax=self.hparams["class_set"] - 1, + cmap=cmap, + interpolation="none", + ) + axs[2].axis("off") + plt.tight_layout() + + # the SummaryWriter is a tensorboard object, see: + # https://pytorch.org/docs/stable/tensorboard.html# + summary_writer: SummaryWriter = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() + + def validation_epoch_end(self, outputs: Any) -> None: + """Logs epoch level validation metrics. + + Args: + outputs: list of items returned by validation_step + """ + self.log_dict(self.val_metrics.compute()) + self.val_metrics.reset() + + def test_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Test step identical to the validation step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the test and validation steps only log per *epoch* + self.log("test_loss", loss, on_step=False, on_epoch=True) + self.test_metrics(y_hat_hard, y) + + def test_epoch_end(self, outputs: Any) -> None: + """Logs epoch level test metrics. + + Args: + outputs: list of items returned by test_step + """ + self.log_dict(self.test_metrics.compute()) + self.test_metrics.reset() + + def configure_optimizers(self) -> Dict[str, Any]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + a "lr dict" according to the pytorch lightning documentation -- + https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers + """ + optimizer = torch.optim.Adam( + self.model.parameters(), lr=self.hparams["learning_rate"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": ReduceLROnPlateau( + optimizer, patience=self.hparams["learning_rate_schedule_patience"] + ), + "monitor": "val_loss", + }, + } + + +class LandcoverAISegmentationTask(pl.LightningModule): + """LightningModule for training models on the Landcover.AI Dataset. + + This allows using arbitrary models and losses from the + ``pytorch_segmentation_models`` package. + """ + + def config_task(self) -> None: + """Configures the task based on kwargs parameters passed to the constructor.""" + if self.hparams["segmentation_model"] == "unet": + self.model = smp.Unet( + encoder_name=self.hparams["encoder_name"], + encoder_weights=self.hparams["encoder_weights"], + in_channels=3, + classes=6, + ) + elif self.hparams["segmentation_model"] == "deeplabv3+": + self.model = smp.DeepLabV3Plus( + encoder_name=self.hparams["encoder_name"], + encoder_weights=self.hparams["encoder_weights"], + in_channels=3, + classes=6, + ) + elif self.hparams["segmentation_model"] == "fcn": + self.model = FCN(in_channels=3, classes=6, num_filters=256) + else: + raise ValueError( + f"Model type '{self.hparams['segmentation_model']}' is not valid." + ) + + if self.hparams["loss"] == "ce": + self.loss = nn.CrossEntropyLoss( # type: ignore[attr-defined] + ignore_index=0 + ) + elif self.hparams["loss"] == "jaccard": + self.loss = smp.losses.JaccardLoss(mode="multiclass", classes=range(1, 6)) + elif self.hparams["loss"] == "focal": + self.loss = smp.losses.FocalLoss( + "multiclass", ignore_index=0, normalized=True + ) + else: + raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LightningModule with a model and loss function. + + Keyword Args: + segmentation_model: Name of the segmentation model type to use + encoder_name: Name of the encoder model backbone to use + encoder_weights: None or "imagenet" to use imagenet pretrained weights in + the encoder model + encoder_output_stride: The output stride parameter in DeepLabV3+ models + loss: Name of the loss function + """ + super().__init__() + self.save_hyperparameters() # creates `self.hparams` from kwargs + + self.config_task() + + self.train_augmentations = K.AugmentationSequential( + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + data_keys=["input", "mask"], + ) + + self.train_metrics = MetricCollection( + [ + Accuracy(num_classes=6, ignore_index=0), + IoU(num_classes=6, ignore_index=0), + ], + prefix="train_", + ) + self.val_metrics = self.train_metrics.clone(prefix="val_") + self.test_metrics = self.train_metrics.clone(prefix="test_") + + def forward(self, x: Tensor) -> Any: # type: ignore[override] + """Forward pass of the model. + + Args: + x: input image + + Returns: + prediction + """ + return self.model(x) + + def training_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> Tensor: + """Training step - reports average accuracy and average IoU. + + Args: + batch: Current batch + batch_idx: Index of current batch + + Returns: + training loss + """ + x = batch["image"] + y = batch["mask"] + with torch.no_grad(): + x, y = self.train_augmentations(x, y) + y = y.long().squeeze() + + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the train step logs every `log_every_n_steps` steps where + # `log_every_n_steps` is a parameter to the `Trainer` object + self.log("train_loss", loss, on_step=True, on_epoch=False) + self.train_metrics(y_hat_hard, y) + + return cast(Tensor, loss) + + def training_epoch_end(self, outputs: Any) -> None: + """Logs epoch level training metrics. + + Args: + outputs: list of items returned by training_step + """ + self.log_dict(self.train_metrics.compute()) + self.train_metrics.reset() + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step - reports average accuracy and average IoU. + + Logs the first 10 validation samples to tensorboard as images with 3 subplots + showing the image, mask, and predictions. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"].long().squeeze() + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat_hard, y) + + if batch_idx < 10 and self.hparams["verbose"]: + # Render the image, ground truth mask, and predicted mask for the first + # image in the batch + img = np.rollaxis( # convert image to channels last format + x[0].cpu().numpy(), 0, 3 + ) + mask = y[0].cpu().numpy() + pred = y_hat_hard[0].cpu().numpy() + fig, axs = plt.subplots(1, 3, figsize=(12, 4)) + axs[0].imshow(img) + axs[0].axis("off") + axs[1].imshow(mask, vmin=0, vmax=5) + axs[1].axis("off") + axs[2].imshow(pred, vmin=0, vmax=5) + axs[2].axis("off") + + # the SummaryWriter is a tensorboard object, see: + # https://pytorch.org/docs/stable/tensorboard.html# + summary_writer: SummaryWriter = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + + plt.close() + + def validation_epoch_end(self, outputs: Any) -> None: + """Logs epoch level validation metrics. + + Args: + outputs: list of items returned by validation_step + """ + self.log_dict(self.val_metrics.compute()) + self.val_metrics.reset() + + def test_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Test step identical to the validation step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"].long().squeeze() + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the test and validation steps only log per *epoch* + self.log("test_loss", loss, on_step=False, on_epoch=True) + self.test_metrics(y_hat_hard, y) + + def test_epoch_end(self, outputs: Any) -> None: + """Logs epoch level test metrics. + + Args: + outputs: list of items returned by test_step + """ + self.log_dict(self.test_metrics.compute()) + self.test_metrics.reset() + + def configure_optimizers(self) -> Dict[str, Any]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + a "lr dict" according to the pytorch lightning documentation -- + https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers + """ + optimizer = torch.optim.Adam( + self.model.parameters(), lr=self.hparams["learning_rate"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": ReduceLROnPlateau( + optimizer, patience=self.hparams["learning_rate_schedule_patience"] + ), + "monitor": "val_loss", + }, + } + + +class NAIPChesapeakeSegmentationTask(pl.LightningModule): + """LightningModule for training models on the NAIP and Chesapeake datasets. + + This allows using arbitrary models and losses from the + ``pytorch_segmentation_models`` package. + """ + + in_channels = 4 + classes = 13 + # TODO: tune this hyperparam + num_filters = 64 + + def config_task(self, kwargs: Any) -> None: + """Configures the task based on kwargs parameters.""" + if kwargs["segmentation_model"] == "unet": + self.model = smp.Unet( + encoder_name=kwargs["encoder_name"], + encoder_weights=kwargs["encoder_weights"], + in_channels=self.in_channels, + classes=self.classes, + ) + elif kwargs["segmentation_model"] == "deeplabv3+": + self.model = smp.DeepLabV3Plus( + encoder_name=kwargs["encoder_name"], + encoder_weights=kwargs["encoder_weights"], + encoder_output_stride=kwargs["encoder_output_stride"], + in_channels=self.in_channels, + classes=self.classes, + ) + elif kwargs["segmentation_model"] == "fcn": + self.model = FCN(self.in_channels, self.classes, self.num_filters) + else: + raise ValueError( + f"Model type '{kwargs['segmentation_model']}' is not valid." + ) + + if kwargs["loss"] == "ce": + self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined] + elif kwargs["loss"] == "jaccard": + self.loss = smp.losses.JaccardLoss(mode="multiclass") + else: + raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.") + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LightningModule with a model and loss function. + + Keyword Args: + segmentation_model: Name of the segmentation model type to use + encoder_name: Name of the encoder model backbone to use + encoder_weights: None or "imagenet" to use imagenet pretrained weights in + the encoder model + encoder_output_stride: The output stride parameter in DeepLabV3+ models + loss: Name of the loss function + """ + super().__init__() + self.save_hyperparameters() # creates `self.hparams` from kwargs + + self.config_task(kwargs) + + self.train_accuracy = Accuracy() + self.val_accuracy = Accuracy() + self.test_accuracy = Accuracy() + + self.train_iou = IoU(num_classes=self.classes) + self.val_iou = IoU(num_classes=self.classes) + self.test_iou = IoU(num_classes=self.classes) + + def forward(self, x: Tensor) -> Any: # type: ignore[override] + """Forward pass of the model. + + Args: + x: input image + + Returns: + prediction + """ + return self.model(x) + + def training_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> Tensor: + """Training step - reports average accuracy and average IoU. + + Args: + batch: current batch + batch_idx: index of current batch + + Returns: + training loss + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the train step logs every `log_every_n_steps` steps where + # `log_every_n_steps` is a parameter to the `Trainer` object + self.log("train_loss", loss, on_step=True, on_epoch=False) + self.train_accuracy(y_hat_hard, y) + self.train_iou(y_hat_hard, y) + + return cast(Tensor, loss) + + def training_epoch_end(self, outputs: Any) -> None: + """Logs epoch level training metrics. + + Args: + outputs: list of items returned by training_step + """ + self.log("train_acc", self.train_accuracy.compute()) + self.log("train_iou", self.train_iou.compute()) + self.train_accuracy.reset() + self.train_iou.reset() + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step - reports average accuracy and average IoU. + + Args: + batch: current batch + batch_idx: index of current batch + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the test and validation steps only log per *epoch* + self.log("val_loss", loss) + self.val_accuracy(y_hat_hard, y) + self.val_iou(y_hat_hard, y) + + if batch_idx < 10: + # Render the image, ground truth mask, and predicted mask for the first + # image in the batch + img = np.rollaxis( # convert image to channels last format + batch["image"][0].cpu().numpy(), 0, 3 + ) + mask = batch["mask"][0].cpu().numpy() + pred = y_hat_hard[0].cpu().numpy() + fig, axs = plt.subplots(1, 3, figsize=(12, 4)) + axs[0].imshow(img) + axs[0].axis("off") + axs[1].imshow(mask, vmin=0, vmax=4) + axs[1].axis("off") + axs[2].imshow(pred, vmin=0, vmax=4) + axs[2].axis("off") + + # the SummaryWriter is a tensorboard object, see: + # https://pytorch.org/docs/stable/tensorboard.html# + summary_writer: SummaryWriter = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + + plt.close() + + def validation_epoch_end(self, outputs: Any) -> None: + """Logs epoch level validation metrics. + + Args: + outputs: list of items returned by validation_step + """ + self.log("val_acc", self.val_accuracy.compute()) + self.log("val_iou", self.val_iou.compute()) + self.val_accuracy.reset() + self.val_iou.reset() + + def test_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Test step identical to the validation step. + + Args: + batch: current batch + batch_idx: index of current batch + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the test and validation steps only log per *epoch* + self.log("test_loss", loss) + self.test_accuracy(y_hat_hard, y) + self.test_iou(y_hat_hard, y) + + def test_epoch_end(self, outputs: Any) -> None: + """Logs epoch level test metrics. + + Args: + outputs: list of items returned by test_step + """ + self.log("test_acc", self.test_accuracy.compute()) + self.log("test_iou", self.test_iou.compute()) + self.test_accuracy.reset() + self.test_iou.reset() + + def configure_optimizers(self) -> Dict[str, Any]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + a "lr dict" according to the pytorch lightning documentation -- + https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers + """ + optimizer = torch.optim.AdamW( + self.model.parameters(), lr=self.hparams["learning_rate"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": ReduceLROnPlateau( + optimizer, patience=self.hparams["learning_rate_schedule_patience"] + ), + "monitor": "val_loss", + "verbose": True, + }, + } + + +class SEN12MSSegmentationTask(pl.LightningModule): + """LightningModule for training models on the SEN12MS Dataset. + + This allows using arbitrary models and losses from the + ``pytorch_segmentation_models`` package. + """ + + def config_task(self) -> None: + """Configures the task based on kwargs parameters.""" + if self.hparams["segmentation_model"] == "unet": + self.model = smp.Unet( + encoder_name=self.hparams["encoder_name"], + encoder_weights=self.hparams["encoder_weights"], + in_channels=self.hparams["in_channels"], + classes=11, + ) + else: + raise ValueError( + f"Model type '{self.hparams['segmentation_model']}' is not valid." + ) + + if self.hparams["loss"] == "ce": + self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined] + elif self.hparams["loss"] == "jaccard": + self.loss = smp.losses.JaccardLoss(mode="multiclass") + else: + raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LightningModule with a model and loss function. + + Keyword Args: + segmentation_model: Name of the segmentation model type to use + encoder_name: Name of the encoder model backbone to use + encoder_weights: None or "imagenet" to use imagenet pretrained weights in + the encoder model + loss: Name of the loss function + """ + super().__init__() + self.save_hyperparameters() # creates `self.hparams` from kwargs + + self.config_task() + + self.train_metrics = MetricCollection([Accuracy()], prefix="train_") + self.val_metrics = self.train_metrics.clone(prefix="val_") + self.test_metrics = self.train_metrics.clone(prefix="test_") + + def forward(self, x: Tensor) -> Any: # type: ignore[override] + """Forward pass of the model. + + Args: + x: input image + + Returns: + prediction + """ + return self.model(x) + + def training_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> Tensor: + """Training step. + + Args: + batch: Current batch + batch_idx: Index of current batch + + Returns: + training loss + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + self.log("train_loss", loss, on_step=True, on_epoch=False) + self.train_metrics(y_hat_hard, y) + + return cast(Tensor, loss) + + def training_epoch_end(self, outputs: Any) -> None: + """Logs epoch level training metrics. + + Args: + outputs: list of items returned by training_step + """ + self.log_dict(self.train_metrics.compute()) + self.train_metrics.reset() + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat_hard, y) + + def validation_epoch_end(self, outputs: Any) -> None: + """Logs epoch level validation metrics. + + Args: + outputs: list of items returned by validation_step + """ + self.log_dict(self.val_metrics.compute()) + self.val_metrics.reset() + + def test_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Test step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the test and validation steps only log per *epoch* + self.log("test_loss", loss, on_step=False, on_epoch=True) + self.test_metrics(y_hat_hard, y) + + def test_epoch_end(self, outputs: Any) -> None: + """Logs epoch level test metrics. + + Args: + outputs: list of items returned by test_step + """ + self.log_dict(self.test_metrics.compute()) + self.test_metrics.reset() + + def configure_optimizers(self) -> Dict[str, Any]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + a "lr dict" according to the pytorch lightning documentation -- + https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers + """ + optimizer = torch.optim.AdamW( + self.model.parameters(), lr=self.hparams["learning_rate"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": ReduceLROnPlateau( + optimizer, patience=self.hparams["learning_rate_schedule_patience"] + ), + "monitor": "val_loss", + }, + } diff --git a/torchgeo/trainers/so2sat.py b/torchgeo/trainers/so2sat.py deleted file mode 100644 index 1384239b58f..00000000000 --- a/torchgeo/trainers/so2sat.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""So2Sat trainer.""" - -import os - -import torch -import torch.nn as nn -import torchvision.models -from torch.nn.modules import Conv2d, Linear - -from . import utils -from .tasks import ClassificationTask - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -Conv2d.__module__ = "nn.Conv2d" -Linear.__module__ = "nn.Linear" - - -class So2SatClassificationTask(ClassificationTask): - """LightningModule for training models on the So2Sat Dataset.""" - - def config_model(self) -> None: - """Configures the model based on kwargs parameters passed to the constructor.""" - in_channels = self.hparams["in_channels"] - - pretrained = False - if not os.path.exists(self.hparams["weights"]): - if self.hparams["weights"] == "imagenet": - pretrained = True - elif self.hparams["weights"] == "random": - pretrained = False - else: - raise ValueError( - f"Weight type '{self.hparams['weights']}' is not valid." - ) - - # Create the model - if "resnet" in self.hparams["classification_model"]: - self.model = getattr( - torchvision.models.resnet, self.hparams["classification_model"] - )(pretrained=pretrained) - in_features = self.model.fc.in_features - self.model.fc = Linear( - in_features, out_features=self.hparams["num_classes"] - ) - - # Update first layer - if in_channels != 3: - w_old = None - if pretrained: - w_old = torch.clone( # type: ignore[attr-defined] - self.model.conv1.weight - ).detach() - # Create the new layer - self.model.conv1 = Conv2d( - in_channels, 64, kernel_size=7, stride=1, padding=2, bias=False - ) - nn.init.kaiming_normal_( # type: ignore[no-untyped-call] - self.model.conv1.weight, mode="fan_out", nonlinearity="relu" - ) - - # We copy over the pretrained RGB weights - if pretrained: - w_new = torch.clone( # type: ignore[attr-defined] - self.model.conv1.weight - ).detach() - w_new[:, :3, :, :] = w_old - self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 - w_new - ) - else: - raise ValueError( - f"Model type '{self.hparams['classification_model']}' is not valid." - ) - - # Load pretrained weights checkpoint weights - if "resnet" in self.hparams["classification_model"]: - if os.path.exists(self.hparams["weights"]): - name, state_dict = utils.extract_encoder(self.hparams["weights"]) - - if self.hparams["classification_model"] != name: - raise ValueError( - f"Trying to load {name} weights into a " - f"{self.hparams['classification_model']}" - ) - - self.model = utils.load_state_dict(self.model, state_dict) From 79b351cdf18ec478d82a91d65819cc061ece2dc0 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 5 Nov 2021 17:45:17 -0500 Subject: [PATCH 2/5] Add SemanticSegmentationTask --- conf/chesapeake_cvpr.yaml | 4 +- conf/landcoverai.yaml | 3 + conf/task_defaults/chesapeake_cvpr.yaml | 6 +- conf/task_defaults/landcoverai.yaml | 3 + conf/task_defaults/naipchesapeake.yaml | 3 + conf/task_defaults/sen12ms.yaml | 1 + experiments/test_chesapeakecvpr_models.py | 2 +- tests/datasets/test_landcoverai.py | 14 +- tests/trainers/test_classification.py | 7 +- tests/trainers/test_segmentation.py | 247 +++----- torchgeo/datasets/__init__.py | 4 +- torchgeo/datasets/landcoverai.py | 6 +- torchgeo/trainers/__init__.py | 19 +- torchgeo/trainers/classification.py | 10 +- torchgeo/trainers/segmentation.py | 652 ++++------------------ train.py | 18 +- 16 files changed, 224 insertions(+), 775 deletions(-) diff --git a/conf/chesapeake_cvpr.yaml b/conf/chesapeake_cvpr.yaml index f73d2241459..8ccb74ad643 100644 --- a/conf/chesapeake_cvpr.yaml +++ b/conf/chesapeake_cvpr.yaml @@ -15,7 +15,9 @@ experiment: encoder_output_stride: 16 learning_rate: 1e-2 learning_rate_schedule_patience: 6 - class_set: 7 + in_channels: 4 + num_classes: 7 + num_filters: 256 datamodule: batch_size: 64 num_workers: 6 diff --git a/conf/landcoverai.yaml b/conf/landcoverai.yaml index a6164e6bb49..ae526557455 100644 --- a/conf/landcoverai.yaml +++ b/conf/landcoverai.yaml @@ -14,6 +14,9 @@ experiment: encoder_output_stride: 16 learning_rate: 1e-3 learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 6 + num_filters: 256 datamodule: batch_size: 32 num_workers: 6 diff --git a/conf/task_defaults/chesapeake_cvpr.yaml b/conf/task_defaults/chesapeake_cvpr.yaml index 39a58e486ff..1087c96f976 100644 --- a/conf/task_defaults/chesapeake_cvpr.yaml +++ b/conf/task_defaults/chesapeake_cvpr.yaml @@ -8,11 +8,13 @@ experiment: encoder_output_stride: 16 learning_rate: 1e-3 learning_rate_schedule_patience: 6 - class_set: 7 + in_channels: 4 + num_classes: 7 + num_filters: 256 datamodule: train_state: "de" patches_per_tile: 200 patch_size: 256 batch_size: 64 num_workers: 4 - class_set: ${experiment.module.class_set} + num_classes: ${experiment.module.num_classes} diff --git a/conf/task_defaults/landcoverai.yaml b/conf/task_defaults/landcoverai.yaml index 66f69f8350e..d3efd7bc67b 100644 --- a/conf/task_defaults/landcoverai.yaml +++ b/conf/task_defaults/landcoverai.yaml @@ -8,6 +8,9 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 6 verbose: false + in_channels: 3 + num_classes: 6 + num_filters: 256 datamodule: batch_size: 32 num_workers: 4 diff --git a/conf/task_defaults/naipchesapeake.yaml b/conf/task_defaults/naipchesapeake.yaml index 6e944afc088..f793fd56d39 100644 --- a/conf/task_defaults/naipchesapeake.yaml +++ b/conf/task_defaults/naipchesapeake.yaml @@ -8,6 +8,9 @@ experiment: encoder_output_stride: 16 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + in_channels: 4 + num_classes: 13 + num_filters: 64 datamodule: batch_size: 32 num_workers: 4 diff --git a/conf/task_defaults/sen12ms.yaml b/conf/task_defaults/sen12ms.yaml index 45193c2cdfe..bd4ceb0292c 100644 --- a/conf/task_defaults/sen12ms.yaml +++ b/conf/task_defaults/sen12ms.yaml @@ -9,6 +9,7 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 2 in_channels: 15 + num_classes: 11 datamodule: batch_size: 32 num_workers: 4 diff --git a/experiments/test_chesapeakecvpr_models.py b/experiments/test_chesapeakecvpr_models.py index 60bb1cc4362..012910630e8 100755 --- a/experiments/test_chesapeakecvpr_models.py +++ b/experiments/test_chesapeakecvpr_models.py @@ -12,7 +12,7 @@ import torch from torchgeo.datasets import ChesapeakeCVPRDataModule -from torchgeo.trainers import ChesapeakeCVPRSegmentationTask +from torchgeo.trainers.segmentation import ChesapeakeCVPRSegmentationTask ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]] diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 53413dc1feb..f197077f225 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -14,7 +14,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import LandCoverAI, LandcoverAIDataModule +from torchgeo.datasets import LandCoverAI, LandCoverAIDataModule def download_url(url: str, root: str, *args: str) -> None: @@ -69,22 +69,22 @@ def test_not_downloaded(self, tmp_path: Path) -> None: LandCoverAI(str(tmp_path)) -class TestLandcoverAIDataModule: +class TestLandCoverAIDataModule: @pytest.fixture(scope="class") - def datamodule(self) -> LandcoverAIDataModule: + def datamodule(self) -> LandCoverAIDataModule: root = os.path.join("tests", "data", "landcoverai") batch_size = 2 num_workers = 0 - dm = LandcoverAIDataModule(root, batch_size, num_workers) + dm = LandCoverAIDataModule(root, batch_size, num_workers) dm.prepare_data() dm.setup() return dm - def test_train_dataloader(self, datamodule: LandcoverAIDataModule) -> None: + def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None: next(iter(datamodule.train_dataloader())) - def test_val_dataloader(self, datamodule: LandcoverAIDataModule) -> None: + def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None: next(iter(datamodule.val_dataloader())) - def test_test_dataloader(self, datamodule: LandcoverAIDataModule) -> None: + def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None: next(iter(datamodule.test_dataloader())) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 3aa414f1b2c..3b602b1729b 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -15,11 +15,8 @@ from torch.utils.data import DataLoader, Dataset, TensorDataset from torchgeo.datasets import So2SatDataModule -from torchgeo.trainers import ( - ClassificationTask, - MultiLabelClassificationTask, - So2SatClassificationTask, -) +from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask +from torchgeo.trainers.classification import So2SatClassificationTask from .test_utils import mocked_log diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 65e21dc89c1..11bc50219a2 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, Tuple, cast +from typing import Any, Dict, Generator, cast import pytest from _pytest.fixtures import SubRequest @@ -11,27 +11,22 @@ from torchgeo.datasets import ( ChesapeakeCVPRDataModule, - LandcoverAIDataModule, + LandCoverAIDataModule, NAIPChesapeakeDataModule, - SEN12MSDataModule, ) -from torchgeo.trainers import ( +from torchgeo.trainers import SemanticSegmentationTask +from torchgeo.trainers.segmentation import ( ChesapeakeCVPRSegmentationTask, - LandcoverAISegmentationTask, + LandCoverAISegmentationTask, NAIPChesapeakeSegmentationTask, - SEN12MSSegmentationTask, ) from .test_utils import FakeTrainer, mocked_log -class TestChesapeakeCVPRSegmentationTask: - @pytest.fixture(scope="class", params=[5, 7]) - def class_set(self, request: SubRequest) -> int: - return cast(int, request.param) - +class TestSemanticSegmentationTask: @pytest.fixture(scope="class") - def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: + def datamodule(self) -> ChesapeakeCVPRDataModule: dm = ChesapeakeCVPRDataModule( os.path.join("tests", "data", "chesapeake", "cvpr"), ["de-test"], @@ -41,7 +36,7 @@ def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: patches_per_tile=2, batch_size=2, num_workers=0, - class_set=class_set, + class_set=7, ) dm.prepare_data() dm.setup() @@ -50,14 +45,13 @@ def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: @pytest.fixture( params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "jaccard", "focal"]) ) - def config(self, class_set: int, request: SubRequest) -> Dict[str, Any]: + def config(self, request: SubRequest) -> Dict[str, Any]: task_conf = OmegaConf.load( os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") ) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) segmentation_model, loss = request.param - task_args["class_set"] = class_set task_args["segmentation_model"] = segmentation_model task_args["loss"] = loss return task_args @@ -65,281 +59,184 @@ def config(self, class_set: int, request: SubRequest) -> Dict[str, Any]: @pytest.fixture def task( self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> ChesapeakeCVPRSegmentationTask: - task = ChesapeakeCVPRSegmentationTask(**config) + ) -> SemanticSegmentationTask: + task = SemanticSegmentationTask(**config) trainer = FakeTrainer() monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] return task - def test_configure_optimizers(self, task: ChesapeakeCVPRSegmentationTask) -> None: + def test_configure_optimizers(self, task: SemanticSegmentationTask) -> None: out = task.configure_optimizers() assert "optimizer" in out assert "lr_scheduler" in out def test_training( - self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask + self, datamodule: ChesapeakeCVPRDataModule, task: SemanticSegmentationTask ) -> None: batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask + self, datamodule: ChesapeakeCVPRDataModule, task: SemanticSegmentationTask ) -> None: batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) def test_test( - self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask + self, datamodule: ChesapeakeCVPRDataModule, task: SemanticSegmentationTask ) -> None: batch = next(iter(datamodule.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0) - def test_invalid_class_set(self, config: Dict[str, Any]) -> None: - config["class_set"] = 6 - error_message = "'class_set' must be either 5 or 7" - with pytest.raises(ValueError, match=error_message): - ChesapeakeCVPRSegmentationTask(**config) - def test_invalid_model(self, config: Dict[str, Any]) -> None: config["segmentation_model"] = "invalid_model" error_message = "Model type 'invalid_model' is not valid." with pytest.raises(ValueError, match=error_message): - ChesapeakeCVPRSegmentationTask(**config) + SemanticSegmentationTask(**config) def test_invalid_loss(self, config: Dict[str, Any]) -> None: config["loss"] = "invalid_loss" error_message = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=error_message): - ChesapeakeCVPRSegmentationTask(**config) + SemanticSegmentationTask(**config) + +class TestChesapeakeCVPRSegmentationTask: + @pytest.fixture(scope="class", params=[5, 7]) + def class_set(self, request: SubRequest) -> int: + return cast(int, request.param) -class TestLandcoverAISegmentationTask: @pytest.fixture(scope="class") - def datamodule(self) -> LandcoverAIDataModule: - root = os.path.join("tests", "data", "landcoverai") - batch_size = 2 - num_workers = 0 - dm = LandcoverAIDataModule(root, batch_size, num_workers) + def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: + dm = ChesapeakeCVPRDataModule( + os.path.join("tests", "data", "chesapeake", "cvpr"), + ["de-test"], + ["de-test"], + ["de-test"], + patch_size=32, + patches_per_tile=2, + batch_size=2, + num_workers=0, + class_set=class_set, + ) dm.prepare_data() dm.setup() return dm - @pytest.fixture( - params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "jaccard", "focal"]) - ) - def config(self, request: SubRequest) -> Dict[str, Any]: + @pytest.fixture + def config(self, class_set: int) -> Dict[str, Any]: task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "landcoverai.yaml") + os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") ) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) - segmentation_model, loss = request.param - task_args["segmentation_model"] = segmentation_model - task_args["loss"] = loss - task_args["verbose"] = True + task_args["num_classes"] = class_set return task_args @pytest.fixture def task( self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> LandcoverAISegmentationTask: - task = LandcoverAISegmentationTask(**config) + ) -> ChesapeakeCVPRSegmentationTask: + task = ChesapeakeCVPRSegmentationTask(**config) trainer = FakeTrainer() monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] return task - def test_training( - self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - def test_validation( - self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask + self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask ) -> None: batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test( - self, datamodule: LandcoverAIDataModule, task: LandcoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - def test_configure_optimizers(self, task: LandcoverAISegmentationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["segmentation_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - LandcoverAISegmentationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - LandcoverAISegmentationTask(**config) - - -class TestNAIPChesapeakeSegmentationTask: +class TestLandCoverAISegmentationTask: @pytest.fixture(scope="class") - def datamodule(self) -> NAIPChesapeakeDataModule: - dm = NAIPChesapeakeDataModule( - os.path.join("tests", "data", "naip"), - os.path.join("tests", "data", "chesapeake", "BAYWIDE"), - batch_size=2, - num_workers=0, - ) - dm.patch_size = 32 + def datamodule(self) -> LandCoverAIDataModule: + root = os.path.join("tests", "data", "landcoverai") + batch_size = 2 + num_workers = 0 + dm = LandCoverAIDataModule(root, batch_size, num_workers) dm.prepare_data() dm.setup() return dm - @pytest.fixture(params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "ce", "jaccard"])) - def config(self, request: SubRequest) -> Dict[str, Any]: + @pytest.fixture + def config(self) -> Dict[str, Any]: task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") + os.path.join("conf", "task_defaults", "landcoverai.yaml") ) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) - segmentation_model, loss = request.param - task_args["segmentation_model"] = segmentation_model - task_args["loss"] = loss + task_args["verbose"] = True return task_args @pytest.fixture def task( self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> NAIPChesapeakeSegmentationTask: - task = NAIPChesapeakeSegmentationTask(**config) + ) -> LandCoverAISegmentationTask: + task = LandCoverAISegmentationTask(**config) trainer = FakeTrainer() monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] return task - def test_configure_optimizers(self, task: NAIPChesapeakeSegmentationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - def test_training( - self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask + self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask ) -> None: batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask + self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask ) -> None: batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test( - self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["segmentation_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - NAIPChesapeakeSegmentationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - NAIPChesapeakeSegmentationTask(**config) - - -class TestSEN12MSSegmentationTask: - @pytest.fixture( - scope="class", - params=[("all", 15), ("s1", 2), ("s2-all", 13), ("s2-reduced", 6)], - ) - def bands(self, request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) +class TestNAIPChesapeakeSegmentationTask: @pytest.fixture(scope="class") - def datamodule(self, bands: Tuple[str, int]) -> SEN12MSDataModule: - root = os.path.join("tests", "data", "sen12ms") - seed = 0 - band_set = bands[0] - batch_size = 1 - num_workers = 0 - dm = SEN12MSDataModule(root, seed, band_set, batch_size, num_workers) + def datamodule(self) -> NAIPChesapeakeDataModule: + dm = NAIPChesapeakeDataModule( + os.path.join("tests", "data", "naip"), + os.path.join("tests", "data", "chesapeake", "BAYWIDE"), + batch_size=2, + num_workers=0, + ) + dm.patch_size = 32 dm.prepare_data() dm.setup() return dm - @pytest.fixture(params=["ce", "jaccard"]) - def config(self, bands: Tuple[str, int], request: SubRequest) -> Dict[str, Any]: + @pytest.fixture + def config(self) -> Dict[str, Any]: task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "sen12ms.yaml") + os.path.join("conf", "task_defaults", "naipchesapeake.yaml") ) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] - task_args["loss"] = request.param return task_args @pytest.fixture def task( self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> SEN12MSSegmentationTask: - task = SEN12MSSegmentationTask(**config) + ) -> NAIPChesapeakeSegmentationTask: + task = NAIPChesapeakeSegmentationTask(**config) + trainer = FakeTrainer() + monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] return task - def test_configure_optimizers(self, task: SEN12MSSegmentationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_training( - self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - def test_validation( - self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask + self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask ) -> None: batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) - - def test_test( - self, datamodule: SEN12MSDataModule, task: SEN12MSSegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["segmentation_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - SEN12MSSegmentationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - SEN12MSSegmentationTask(**config) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 2c60708d792..8ef1ababe6a 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -36,7 +36,7 @@ ZipDataset, ) from .gid15 import GID15 -from .landcoverai import LandCoverAI, LandcoverAIDataModule +from .landcoverai import LandCoverAI, LandCoverAIDataModule from .landsat import ( Landsat, Landsat1, @@ -108,7 +108,7 @@ "EuroSAT", "GID15", "LandCoverAI", - "LandcoverAIDataModule", + "LandCoverAIDataModule", "LEVIRCDPlus", "PatternNet", "RESISC45", diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index e5381378bb1..2e9b58e770a 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -208,8 +208,8 @@ def _download(self) -> None: exec(split) -class LandcoverAIDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the Landcover.AI dataset. +class LandCoverAIDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the LandCover.ai dataset. Uses the train/val/test splits from the dataset. """ @@ -217,7 +217,7 @@ class LandcoverAIDataModule(pl.LightningDataModule): def __init__( self, root_dir: str, batch_size: int = 64, num_workers: int = 4, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for Landcover.AI based DataLoaders. + """Initialize a LightningDataModule for LandCover.ai based DataLoaders. Args: root_dir: The ``root`` arugment to pass to the Landcover.AI Dataset classes diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index ad398a1e7cd..2b7ce33f4b6 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -4,29 +4,16 @@ """TorchGeo trainers.""" from .byol import BYOLTask -from .classification import ( - ClassificationTask, - MultiLabelClassificationTask, - So2SatClassificationTask, -) +from .classification import ClassificationTask, MultiLabelClassificationTask from .regression import RegressionTask -from .segmentation import ( - ChesapeakeCVPRSegmentationTask, - LandcoverAISegmentationTask, - NAIPChesapeakeSegmentationTask, - SEN12MSSegmentationTask, -) +from .segmentation import SemanticSegmentationTask __all__ = ( "BYOLTask", - "ChesapeakeCVPRSegmentationTask", "ClassificationTask", - "LandcoverAISegmentationTask", "MultiLabelClassificationTask", - "NAIPChesapeakeSegmentationTask", "RegressionTask", - "SEN12MSSegmentationTask", - "So2SatClassificationTask", + "SemanticSegmentationTask", ) # https://stackoverflow.com/questions/40018681 diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index e8a00459ee4..820754c7dbf 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -26,7 +26,7 @@ class ClassificationTask(pl.LightningModule): - """Abstract base class for image classification LightningModules.""" + """LightningModule for image classification.""" def config_model(self) -> None: """Configures the model based on kwargs parameters passed to the constructor.""" @@ -242,7 +242,7 @@ def configure_optimizers(self) -> Dict[str, Any]: class MultiLabelClassificationTask(ClassificationTask): - """Abstract base class for multi label image classification LightningModules.""" + """LightningModule for multi-label image classification.""" def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" @@ -358,7 +358,11 @@ def test_step( # type: ignore[override] # TODO: move this functionality into ClassificationTask and remove this class class So2SatClassificationTask(ClassificationTask): - """LightningModule for training models on the So2Sat Dataset.""" + """LightningModule for training models on the So2Sat Dataset. + + .. deprecated:: 0.1 + Use :class:`ClassificationTask` instead. + """ def config_model(self) -> None: """Configures the model based on kwargs parameters passed to the constructor.""" diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 4637950b381..cef08acad90 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -9,7 +9,6 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import segmentation_models_pytorch as smp import torch import torch.nn as nn @@ -45,37 +44,32 @@ ) -# TODO: combine all of these classes into a single SemanticSegmentationTask -class ChesapeakeCVPRSegmentationTask(LightningModule): - """LightningModule for training models on the Chesapeake CVPR Land Cover dataset. - - This allows using arbitrary models and losses from the - ``pytorch_segmentation_models`` package. - """ +class SemanticSegmentationTask(LightningModule): + """LightningModule for semantic segmentation of images.""" def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" - if self.hparams["class_set"] not in [5, 7]: - raise ValueError("'class_set' must be either 5 or 7") - num_classes = self.hparams["class_set"] - classes = range(1, self.hparams["class_set"]) if self.hparams["segmentation_model"] == "unet": self.model = smp.Unet( encoder_name=self.hparams["encoder_name"], encoder_weights=self.hparams["encoder_weights"], - in_channels=4, - classes=num_classes, + in_channels=self.hparams["in_channels"], + classes=self.hparams["num_classes"], ) elif self.hparams["segmentation_model"] == "deeplabv3+": self.model = smp.DeepLabV3Plus( encoder_name=self.hparams["encoder_name"], encoder_weights=self.hparams["encoder_weights"], - in_channels=4, - classes=num_classes, + in_channels=self.hparams["in_channels"], + classes=self.hparams["num_classes"], ) elif self.hparams["segmentation_model"] == "fcn": - self.model = FCN(in_channels=4, classes=num_classes, num_filters=256) + self.model = FCN( + in_channels=self.hparams["in_channels"], + classes=self.hparams["num_classes"], + num_filters=self.hparams["num_filters"], + ) else: raise ValueError( f"Model type '{self.hparams['segmentation_model']}' is not valid." @@ -86,7 +80,9 @@ def config_task(self) -> None: ignore_index=0 ) elif self.hparams["loss"] == "jaccard": - self.loss = smp.losses.JaccardLoss(mode="multiclass", classes=classes) + self.loss = smp.losses.JaccardLoss( + mode="multiclass", classes=self.hparams["num_classes"] + ) elif self.hparams["loss"] == "focal": self.loss = smp.losses.FocalLoss( "multiclass", ignore_index=0, normalized=True @@ -102,6 +98,8 @@ def __init__(self, **kwargs: Any) -> None: encoder_name: Name of the encoder model backbone to use encoder_weights: None or "imagenet" to use imagenet pretrained weights in the encoder model + in_channels: Number of channels in input image + num_classes: Number of semantic classes to predict loss: Name of the loss function Raises: @@ -114,8 +112,8 @@ def __init__(self, **kwargs: Any) -> None: self.train_metrics = MetricCollection( [ - Accuracy(num_classes=self.hparams["class_set"], ignore_index=0), - IoU(num_classes=self.hparams["class_set"], ignore_index=0), + Accuracy(num_classes=self.hparams["num_classes"], ignore_index=0), + IoU(num_classes=self.hparams["num_classes"], ignore_index=0), ], prefix="train_", ) @@ -190,48 +188,6 @@ def validation_step( # type: ignore[override] self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) - if batch_idx < 10: - cmap = None - if self.hparams["class_set"] == 5: - cmap = CMAP_5 - else: - cmap = CMAP_7 - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - batch["image"][0].cpu().numpy(), 0, 3 - ) - mask = batch["mask"][0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img[:, :, :3]) - axs[0].axis("off") - axs[1].imshow( - mask, - vmin=0, - vmax=self.hparams["class_set"] - 1, - cmap=cmap, - interpolation="none", - ) - axs[1].axis("off") - axs[2].imshow( - pred, - vmin=0, - vmax=self.hparams["class_set"] - 1, - cmap=cmap, - interpolation="none", - ) - axs[2].axis("off") - plt.tight_layout() - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. @@ -291,94 +247,96 @@ def configure_optimizers(self) -> Dict[str, Any]: } -class LandcoverAISegmentationTask(pl.LightningModule): - """LightningModule for training models on the Landcover.AI Dataset. +# TODO: refactor any differences between these classes and SemanticSegmentationTask +# so that these classes are no longer needed. +class ChesapeakeCVPRSegmentationTask(SemanticSegmentationTask): + """LightningModule for training models on the Chesapeake CVPR Land Cover dataset. - This allows using arbitrary models and losses from the - ``pytorch_segmentation_models`` package. + .. deprecated: 0.1 + Use :class:`SemanticSegmentationTask` instead. """ - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" - if self.hparams["segmentation_model"] == "unet": - self.model = smp.Unet( - encoder_name=self.hparams["encoder_name"], - encoder_weights=self.hparams["encoder_weights"], - in_channels=3, - classes=6, - ) - elif self.hparams["segmentation_model"] == "deeplabv3+": - self.model = smp.DeepLabV3Plus( - encoder_name=self.hparams["encoder_name"], - encoder_weights=self.hparams["encoder_weights"], - in_channels=3, - classes=6, - ) - elif self.hparams["segmentation_model"] == "fcn": - self.model = FCN(in_channels=3, classes=6, num_filters=256) - else: - raise ValueError( - f"Model type '{self.hparams['segmentation_model']}' is not valid." - ) - - if self.hparams["loss"] == "ce": - self.loss = nn.CrossEntropyLoss( # type: ignore[attr-defined] - ignore_index=0 - ) - elif self.hparams["loss"] == "jaccard": - self.loss = smp.losses.JaccardLoss(mode="multiclass", classes=range(1, 6)) - elif self.hparams["loss"] == "focal": - self.loss = smp.losses.FocalLoss( - "multiclass", ignore_index=0, normalized=True - ) - else: - raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step - reports average accuracy and average IoU. - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. + Logs the first 10 validation samples to tensorboard as images with 3 subplots + showing the image, mask, and predictions. - Keyword Args: - segmentation_model: Name of the segmentation model type to use - encoder_name: Name of the encoder model backbone to use - encoder_weights: None or "imagenet" to use imagenet pretrained weights in - the encoder model - encoder_output_stride: The output stride parameter in DeepLabV3+ models - loss: Name of the loss function + Args: + batch: Current batch + batch_idx: Index of current batch """ - super().__init__() - self.save_hyperparameters() # creates `self.hparams` from kwargs + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) - self.config_task() + loss = self.loss(y_hat, y) - self.train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["input", "mask"], - ) + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat_hard, y) - self.train_metrics = MetricCollection( - [ - Accuracy(num_classes=6, ignore_index=0), - IoU(num_classes=6, ignore_index=0), - ], - prefix="train_", - ) - self.val_metrics = self.train_metrics.clone(prefix="val_") - self.test_metrics = self.train_metrics.clone(prefix="test_") + if batch_idx < 10: + cmap = None + if self.hparams["num_classes"] == 5: + cmap = CMAP_5 + else: + cmap = CMAP_7 + # Render the image, ground truth mask, and predicted mask for the first + # image in the batch + img = np.rollaxis( # convert image to channels last format + batch["image"][0].cpu().numpy(), 0, 3 + ) + mask = batch["mask"][0].cpu().numpy() + pred = y_hat_hard[0].cpu().numpy() + fig, axs = plt.subplots(1, 3, figsize=(12, 4)) + axs[0].imshow(img[:, :, :3]) + axs[0].axis("off") + axs[1].imshow( + mask, + vmin=0, + vmax=self.hparams["num_classes"] - 1, + cmap=cmap, + interpolation="none", + ) + axs[1].axis("off") + axs[2].imshow( + pred, + vmin=0, + vmax=self.hparams["num_classes"] - 1, + cmap=cmap, + interpolation="none", + ) + axs[2].axis("off") + plt.tight_layout() - def forward(self, x: Tensor) -> Any: # type: ignore[override] - """Forward pass of the model. + # the SummaryWriter is a tensorboard object, see: + # https://pytorch.org/docs/stable/tensorboard.html# + summary_writer: SummaryWriter = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() - Args: - x: input image - Returns: - prediction - """ - return self.model(x) +class LandCoverAISegmentationTask(SemanticSegmentationTask): + """LightningModule for training models on the Landcover.AI Dataset. + + .. deprecated: 0.1 + Use :class:`SemanticSegmentationTask` instead. + """ + + # TODO: move this to LandCoverAIDataModule + train_augmentations = K.AugmentationSequential( + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + data_keys=["input", "mask"], + ) def training_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int @@ -410,15 +368,6 @@ def training_step( # type: ignore[override] return cast(Tensor, loss) - def training_epoch_end(self, outputs: Any) -> None: - """Logs epoch level training metrics. - - Args: - outputs: list of items returned by training_step - """ - self.log_dict(self.train_metrics.compute()) - self.train_metrics.reset() - def validation_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> None: @@ -466,181 +415,14 @@ def validation_step( # type: ignore[override] plt.close() - def validation_epoch_end(self, outputs: Any) -> None: - """Logs epoch level validation metrics. - Args: - outputs: list of items returned by validation_step - """ - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() - - def test_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Test step identical to the validation step. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"].long().squeeze() - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss, on_step=False, on_epoch=True) - self.test_metrics(y_hat_hard, y) - - def test_epoch_end(self, outputs: Any) -> None: - """Logs epoch level test metrics. - - Args: - outputs: list of items returned by test_step - """ - self.log_dict(self.test_metrics.compute()) - self.test_metrics.reset() - - def configure_optimizers(self) -> Dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - a "lr dict" according to the pytorch lightning documentation -- - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers - """ - optimizer = torch.optim.Adam( - self.model.parameters(), lr=self.hparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, patience=self.hparams["learning_rate_schedule_patience"] - ), - "monitor": "val_loss", - }, - } - - -class NAIPChesapeakeSegmentationTask(pl.LightningModule): +class NAIPChesapeakeSegmentationTask(SemanticSegmentationTask): """LightningModule for training models on the NAIP and Chesapeake datasets. - This allows using arbitrary models and losses from the - ``pytorch_segmentation_models`` package. + .. deprecated: 0.1 + Use :class:`SemanticSegmentationTask` instead. """ - in_channels = 4 - classes = 13 - # TODO: tune this hyperparam - num_filters = 64 - - def config_task(self, kwargs: Any) -> None: - """Configures the task based on kwargs parameters.""" - if kwargs["segmentation_model"] == "unet": - self.model = smp.Unet( - encoder_name=kwargs["encoder_name"], - encoder_weights=kwargs["encoder_weights"], - in_channels=self.in_channels, - classes=self.classes, - ) - elif kwargs["segmentation_model"] == "deeplabv3+": - self.model = smp.DeepLabV3Plus( - encoder_name=kwargs["encoder_name"], - encoder_weights=kwargs["encoder_weights"], - encoder_output_stride=kwargs["encoder_output_stride"], - in_channels=self.in_channels, - classes=self.classes, - ) - elif kwargs["segmentation_model"] == "fcn": - self.model = FCN(self.in_channels, self.classes, self.num_filters) - else: - raise ValueError( - f"Model type '{kwargs['segmentation_model']}' is not valid." - ) - - if kwargs["loss"] == "ce": - self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined] - elif kwargs["loss"] == "jaccard": - self.loss = smp.losses.JaccardLoss(mode="multiclass") - else: - raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - segmentation_model: Name of the segmentation model type to use - encoder_name: Name of the encoder model backbone to use - encoder_weights: None or "imagenet" to use imagenet pretrained weights in - the encoder model - encoder_output_stride: The output stride parameter in DeepLabV3+ models - loss: Name of the loss function - """ - super().__init__() - self.save_hyperparameters() # creates `self.hparams` from kwargs - - self.config_task(kwargs) - - self.train_accuracy = Accuracy() - self.val_accuracy = Accuracy() - self.test_accuracy = Accuracy() - - self.train_iou = IoU(num_classes=self.classes) - self.val_iou = IoU(num_classes=self.classes) - self.test_iou = IoU(num_classes=self.classes) - - def forward(self, x: Tensor) -> Any: # type: ignore[override] - """Forward pass of the model. - - Args: - x: input image - - Returns: - prediction - """ - return self.model(x) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step - reports average accuracy and average IoU. - - Args: - batch: current batch - batch_idx: index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_accuracy(y_hat_hard, y) - self.train_iou(y_hat_hard, y) - - return cast(Tensor, loss) - - def training_epoch_end(self, outputs: Any) -> None: - """Logs epoch level training metrics. - - Args: - outputs: list of items returned by training_step - """ - self.log("train_acc", self.train_accuracy.compute()) - self.log("train_iou", self.train_iou.compute()) - self.train_accuracy.reset() - self.train_iou.reset() - def validation_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> None: @@ -659,8 +441,7 @@ def validation_step( # type: ignore[override] # by default, the test and validation steps only log per *epoch* self.log("val_loss", loss) - self.val_accuracy(y_hat_hard, y) - self.val_iou(y_hat_hard, y) + self.val_metrics(y_hat_hard, y) if batch_idx < 10: # Render the image, ground truth mask, and predicted mask for the first @@ -686,236 +467,3 @@ def validation_step( # type: ignore[override] ) plt.close() - - def validation_epoch_end(self, outputs: Any) -> None: - """Logs epoch level validation metrics. - - Args: - outputs: list of items returned by validation_step - """ - self.log("val_acc", self.val_accuracy.compute()) - self.log("val_iou", self.val_iou.compute()) - self.val_accuracy.reset() - self.val_iou.reset() - - def test_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Test step identical to the validation step. - - Args: - batch: current batch - batch_idx: index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss) - self.test_accuracy(y_hat_hard, y) - self.test_iou(y_hat_hard, y) - - def test_epoch_end(self, outputs: Any) -> None: - """Logs epoch level test metrics. - - Args: - outputs: list of items returned by test_step - """ - self.log("test_acc", self.test_accuracy.compute()) - self.log("test_iou", self.test_iou.compute()) - self.test_accuracy.reset() - self.test_iou.reset() - - def configure_optimizers(self) -> Dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - a "lr dict" according to the pytorch lightning documentation -- - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers - """ - optimizer = torch.optim.AdamW( - self.model.parameters(), lr=self.hparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, patience=self.hparams["learning_rate_schedule_patience"] - ), - "monitor": "val_loss", - "verbose": True, - }, - } - - -class SEN12MSSegmentationTask(pl.LightningModule): - """LightningModule for training models on the SEN12MS Dataset. - - This allows using arbitrary models and losses from the - ``pytorch_segmentation_models`` package. - """ - - def config_task(self) -> None: - """Configures the task based on kwargs parameters.""" - if self.hparams["segmentation_model"] == "unet": - self.model = smp.Unet( - encoder_name=self.hparams["encoder_name"], - encoder_weights=self.hparams["encoder_weights"], - in_channels=self.hparams["in_channels"], - classes=11, - ) - else: - raise ValueError( - f"Model type '{self.hparams['segmentation_model']}' is not valid." - ) - - if self.hparams["loss"] == "ce": - self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined] - elif self.hparams["loss"] == "jaccard": - self.loss = smp.losses.JaccardLoss(mode="multiclass") - else: - raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - segmentation_model: Name of the segmentation model type to use - encoder_name: Name of the encoder model backbone to use - encoder_weights: None or "imagenet" to use imagenet pretrained weights in - the encoder model - loss: Name of the loss function - """ - super().__init__() - self.save_hyperparameters() # creates `self.hparams` from kwargs - - self.config_task() - - self.train_metrics = MetricCollection([Accuracy()], prefix="train_") - self.val_metrics = self.train_metrics.clone(prefix="val_") - self.test_metrics = self.train_metrics.clone(prefix="test_") - - def forward(self, x: Tensor) -> Any: # type: ignore[override] - """Forward pass of the model. - - Args: - x: input image - - Returns: - prediction - """ - return self.model(x) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step. - - Args: - batch: Current batch - batch_idx: Index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_metrics(y_hat_hard, y) - - return cast(Tensor, loss) - - def training_epoch_end(self, outputs: Any) -> None: - """Logs epoch level training metrics. - - Args: - outputs: list of items returned by training_step - """ - self.log_dict(self.train_metrics.compute()) - self.train_metrics.reset() - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat_hard, y) - - def validation_epoch_end(self, outputs: Any) -> None: - """Logs epoch level validation metrics. - - Args: - outputs: list of items returned by validation_step - """ - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() - - def test_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Test step. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss, on_step=False, on_epoch=True) - self.test_metrics(y_hat_hard, y) - - def test_epoch_end(self, outputs: Any) -> None: - """Logs epoch level test metrics. - - Args: - outputs: list of items returned by test_step - """ - self.log_dict(self.test_metrics.compute()) - self.test_metrics.reset() - - def configure_optimizers(self) -> Dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - a "lr dict" according to the pytorch lightning documentation -- - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers - """ - optimizer = torch.optim.AdamW( - self.model.parameters(), lr=self.hparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, patience=self.hparams["learning_rate_schedule_patience"] - ), - "monitor": "val_loss", - }, - } diff --git a/train.py b/train.py index 781dcd592dc..668be45852f 100755 --- a/train.py +++ b/train.py @@ -18,7 +18,7 @@ ChesapeakeCVPRDataModule, COWCCountingDataModule, CycloneDataModule, - LandcoverAIDataModule, + LandCoverAIDataModule, NAIPChesapeakeDataModule, RESISC45DataModule, SEN12MSDataModule, @@ -27,14 +27,16 @@ ) from torchgeo.trainers import ( BYOLTask, - ChesapeakeCVPRSegmentationTask, ClassificationTask, - LandcoverAISegmentationTask, MultiLabelClassificationTask, - NAIPChesapeakeSegmentationTask, RegressionTask, - SEN12MSSegmentationTask, - So2SatClassificationTask, + SemanticSegmentationTask, +) +from torchgeo.trainers.classification import So2SatClassificationTask +from torchgeo.trainers.segmentation import ( + ChesapeakeCVPRSegmentationTask, + LandCoverAISegmentationTask, + NAIPChesapeakeSegmentationTask, ) TASK_TO_MODULES_MAPPING: Dict[ @@ -45,10 +47,10 @@ "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), "cowc_counting": (RegressionTask, COWCCountingDataModule), "cyclone": (RegressionTask, CycloneDataModule), - "landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule), + "landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule), "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule), "resisc45": (ClassificationTask, RESISC45DataModule), - "sen12ms": (SEN12MSSegmentationTask, SEN12MSDataModule), + "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule), "so2sat": (So2SatClassificationTask, So2SatDataModule), "ucmerced": (ClassificationTask, UCMercedDataModule), } From b71693055e57fb0c82d8e3f42946230fa9ce41e0 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 5 Nov 2021 17:52:13 -0500 Subject: [PATCH 3/5] Fix doc tests --- docs/api/datasets.rst | 2 +- torchgeo/trainers/segmentation.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 1e6131b60f2..a333f194cf2 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -120,7 +120,7 @@ LandCover.ai (Land Cover from Aerial Imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: LandCoverAI -.. autoclass:: LandcoverAIDataModule +.. autoclass:: LandCoverAIDataModule LEVIR-CD+ (LEVIR Change Detection +) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index cef08acad90..5c52ff3ba7a 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -49,7 +49,6 @@ class SemanticSegmentationTask(LightningModule): def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" - if self.hparams["segmentation_model"] == "unet": self.model = smp.Unet( encoder_name=self.hparams["encoder_name"], From e051d40debd0e4d7423c8ca369e0917250417b73 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 6 Nov 2021 18:32:13 -0500 Subject: [PATCH 4/5] Keep dataset-specific tasks in separate files --- experiments/test_chesapeakecvpr_models.py | 2 +- tests/trainers/test_chesapeake.py | 65 ++++++ tests/trainers/test_classification.py | 106 +--------- tests/trainers/test_landcoverai.py | 60 ++++++ tests/trainers/test_naipchesapeake.py | 55 +++++ tests/trainers/test_segmentation.py | 148 +------------ tests/trainers/test_so2sat.py | 111 ++++++++++ torchgeo/trainers/chesapeake.py | 104 +++++++++ torchgeo/trainers/landcoverai.py | 111 ++++++++++ torchgeo/trainers/naipchesapeake.py | 66 ++++++ torchgeo/trainers/segmentation.py | 245 ---------------------- torchgeo/trainers/so2sat.py | 90 ++++++++ train.py | 10 +- 13 files changed, 669 insertions(+), 504 deletions(-) create mode 100644 tests/trainers/test_chesapeake.py create mode 100644 tests/trainers/test_landcoverai.py create mode 100644 tests/trainers/test_naipchesapeake.py create mode 100644 tests/trainers/test_so2sat.py create mode 100644 torchgeo/trainers/chesapeake.py create mode 100644 torchgeo/trainers/landcoverai.py create mode 100644 torchgeo/trainers/naipchesapeake.py create mode 100644 torchgeo/trainers/so2sat.py diff --git a/experiments/test_chesapeakecvpr_models.py b/experiments/test_chesapeakecvpr_models.py index 012910630e8..249dd20bc7d 100755 --- a/experiments/test_chesapeakecvpr_models.py +++ b/experiments/test_chesapeakecvpr_models.py @@ -12,7 +12,7 @@ import torch from torchgeo.datasets import ChesapeakeCVPRDataModule -from torchgeo.trainers.segmentation import ChesapeakeCVPRSegmentationTask +from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]] diff --git a/tests/trainers/test_chesapeake.py b/tests/trainers/test_chesapeake.py new file mode 100644 index 00000000000..a9c95907dbc --- /dev/null +++ b/tests/trainers/test_chesapeake.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Any, Dict, Generator, cast + +import pytest +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch +from omegaconf import OmegaConf + +from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask + +from .test_utils import FakeTrainer, mocked_log + + +class TestChesapeakeCVPRSegmentationTask: + @pytest.fixture(scope="class", params=[5, 7]) + def class_set(self, request: SubRequest) -> int: + return cast(int, request.param) + + @pytest.fixture(scope="class") + def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: + dm = ChesapeakeCVPRDataModule( + os.path.join("tests", "data", "chesapeake", "cvpr"), + ["de-test"], + ["de-test"], + ["de-test"], + patch_size=32, + patches_per_tile=2, + batch_size=2, + num_workers=0, + class_set=class_set, + ) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture + def config(self, class_set: int) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + task_args["num_classes"] = class_set + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> ChesapeakeCVPRSegmentationTask: + task = ChesapeakeCVPRSegmentationTask(**config) + trainer = FakeTrainer() + monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_validation( + self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 3b602b1729b..55eac33f155 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, Optional, Tuple, cast +from typing import Any, Dict, Generator, Optional, cast import pytest import pytorch_lightning as pl @@ -14,9 +14,7 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset, TensorDataset -from torchgeo.datasets import So2SatDataModule from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask -from torchgeo.trainers.classification import So2SatClassificationTask from .test_utils import mocked_log @@ -251,105 +249,3 @@ def test_invalid_loss(self, config: Dict[str, Any]) -> None: error_message = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=error_message): MultiLabelClassificationTask(**config) - - -class TestSo2SatClassificationTask: - @pytest.fixture(scope="class", params=[("rgb", 3), ("s2", 10)]) - def bands(self, request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - @pytest.fixture(scope="class", params=[True, False]) - def datamodule( - self, bands: Tuple[str, int], request: SubRequest - ) -> So2SatDataModule: - band_set = bands[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "so2sat") - batch_size = 2 - num_workers = 0 - dm = So2SatDataModule( - root, batch_size, num_workers, band_set, unsupervised_mode - ) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture( - params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) - ) - def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] - loss, weights = request.param - task_args["loss"] = loss - task_args["weights"] = weights - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> So2SatClassificationTask: - task = So2SatClassificationTask(**config) - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_configure_optimizers(self, task: So2SatClassificationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_training( - self, datamodule: So2SatDataModule, task: So2SatClassificationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: So2SatDataModule, task: So2SatClassificationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test( - self, datamodule: So2SatDataModule, task: So2SatClassificationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_pretrained(self, checkpoint: str) -> None: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["weights"] = checkpoint - with pytest.warns(UserWarning): - So2SatClassificationTask(**task_args) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["classification_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - So2SatClassificationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - So2SatClassificationTask(**config) - - def test_invalid_weights(self, config: Dict[str, Any]) -> None: - config["weights"] = "invalid_weights" - error_message = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=error_message): - So2SatClassificationTask(**config) - - def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> None: - config["weights"] = checkpoint - config["classification_model"] = "resnet50" - error_message = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=error_message): - So2SatClassificationTask(**config) diff --git a/tests/trainers/test_landcoverai.py b/tests/trainers/test_landcoverai.py new file mode 100644 index 00000000000..ed8d4448a4a --- /dev/null +++ b/tests/trainers/test_landcoverai.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Any, Dict, Generator, cast + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from omegaconf import OmegaConf + +from torchgeo.datasets import LandCoverAIDataModule +from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask + +from .test_utils import FakeTrainer, mocked_log + + +class TestLandCoverAISegmentationTask: + @pytest.fixture(scope="class") + def datamodule(self) -> LandCoverAIDataModule: + root = os.path.join("tests", "data", "landcoverai") + batch_size = 2 + num_workers = 0 + dm = LandCoverAIDataModule(root, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture + def config(self) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "landcoverai.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + task_args["verbose"] = True + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> LandCoverAISegmentationTask: + task = LandCoverAISegmentationTask(**config) + trainer = FakeTrainer() + monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_training( + self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) diff --git a/tests/trainers/test_naipchesapeake.py b/tests/trainers/test_naipchesapeake.py new file mode 100644 index 00000000000..3b8cce5aca0 --- /dev/null +++ b/tests/trainers/test_naipchesapeake.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Any, Dict, Generator, cast + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from omegaconf import OmegaConf + +from torchgeo.datasets import NAIPChesapeakeDataModule +from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask + +from .test_utils import FakeTrainer, mocked_log + + +class TestNAIPChesapeakeSegmentationTask: + @pytest.fixture(scope="class") + def datamodule(self) -> NAIPChesapeakeDataModule: + dm = NAIPChesapeakeDataModule( + os.path.join("tests", "data", "naip"), + os.path.join("tests", "data", "chesapeake", "BAYWIDE"), + batch_size=2, + num_workers=0, + ) + dm.patch_size = 32 + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture + def config(self) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "naipchesapeake.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> NAIPChesapeakeSegmentationTask: + task = NAIPChesapeakeSegmentationTask(**config) + trainer = FakeTrainer() + monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_validation( + self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 11bc50219a2..058f94170b7 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -9,17 +9,8 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import ( - ChesapeakeCVPRDataModule, - LandCoverAIDataModule, - NAIPChesapeakeDataModule, -) +from torchgeo.datasets import ChesapeakeCVPRDataModule from torchgeo.trainers import SemanticSegmentationTask -from torchgeo.trainers.segmentation import ( - ChesapeakeCVPRSegmentationTask, - LandCoverAISegmentationTask, - NAIPChesapeakeSegmentationTask, -) from .test_utils import FakeTrainer, mocked_log @@ -103,140 +94,3 @@ def test_invalid_loss(self, config: Dict[str, Any]) -> None: error_message = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=error_message): SemanticSegmentationTask(**config) - - -class TestChesapeakeCVPRSegmentationTask: - @pytest.fixture(scope="class", params=[5, 7]) - def class_set(self, request: SubRequest) -> int: - return cast(int, request.param) - - @pytest.fixture(scope="class") - def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: - dm = ChesapeakeCVPRDataModule( - os.path.join("tests", "data", "chesapeake", "cvpr"), - ["de-test"], - ["de-test"], - ["de-test"], - patch_size=32, - patches_per_tile=2, - batch_size=2, - num_workers=0, - class_set=class_set, - ) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture - def config(self, class_set: int) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["num_classes"] = class_set - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> ChesapeakeCVPRSegmentationTask: - task = ChesapeakeCVPRSegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_validation( - self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - -class TestLandCoverAISegmentationTask: - @pytest.fixture(scope="class") - def datamodule(self) -> LandCoverAIDataModule: - root = os.path.join("tests", "data", "landcoverai") - batch_size = 2 - num_workers = 0 - dm = LandCoverAIDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture - def config(self) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "landcoverai.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["verbose"] = True - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> LandCoverAISegmentationTask: - task = LandCoverAISegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_training( - self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - -class TestNAIPChesapeakeSegmentationTask: - @pytest.fixture(scope="class") - def datamodule(self) -> NAIPChesapeakeDataModule: - dm = NAIPChesapeakeDataModule( - os.path.join("tests", "data", "naip"), - os.path.join("tests", "data", "chesapeake", "BAYWIDE"), - batch_size=2, - num_workers=0, - ) - dm.patch_size = 32 - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture - def config(self) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "naipchesapeake.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> NAIPChesapeakeSegmentationTask: - task = NAIPChesapeakeSegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_validation( - self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) diff --git a/tests/trainers/test_so2sat.py b/tests/trainers/test_so2sat.py new file mode 100644 index 00000000000..7f117df6878 --- /dev/null +++ b/tests/trainers/test_so2sat.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Any, Dict, Generator, Tuple, cast + +import pytest +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch +from omegaconf import OmegaConf + +from torchgeo.datasets import So2SatDataModule +from torchgeo.trainers.so2sat import So2SatClassificationTask + +from .test_utils import mocked_log + + +class TestSo2SatClassificationTask: + @pytest.fixture(scope="class", params=[("rgb", 3), ("s2", 10)]) + def bands(self, request: SubRequest) -> Tuple[str, int]: + return cast(Tuple[str, int], request.param) + + @pytest.fixture(scope="class", params=[True, False]) + def datamodule( + self, bands: Tuple[str, int], request: SubRequest + ) -> So2SatDataModule: + band_set = bands[0] + unsupervised_mode = request.param + root = os.path.join("tests", "data", "so2sat") + batch_size = 2 + num_workers = 0 + dm = So2SatDataModule( + root, batch_size, num_workers, band_set, unsupervised_mode + ) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture( + params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) + ) + def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: + task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + task_args["in_channels"] = bands[1] + loss, weights = request.param + task_args["loss"] = loss + task_args["weights"] = weights + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> So2SatClassificationTask: + task = So2SatClassificationTask(**config) + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_configure_optimizers(self, task: So2SatClassificationTask) -> None: + out = task.configure_optimizers() + assert "optimizer" in out + assert "lr_scheduler" in out + + def test_training( + self, datamodule: So2SatDataModule, task: So2SatClassificationTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: So2SatDataModule, task: So2SatClassificationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test( + self, datamodule: So2SatDataModule, task: So2SatClassificationTask + ) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_pretrained(self, checkpoint: str) -> None: + task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + task_args["weights"] = checkpoint + with pytest.warns(UserWarning): + So2SatClassificationTask(**task_args) + + def test_invalid_model(self, config: Dict[str, Any]) -> None: + config["classification_model"] = "invalid_model" + error_message = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=error_message): + So2SatClassificationTask(**config) + + def test_invalid_weights(self, config: Dict[str, Any]) -> None: + config["weights"] = "invalid_weights" + error_message = "Weight type 'invalid_weights' is not valid." + with pytest.raises(ValueError, match=error_message): + So2SatClassificationTask(**config) + + def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> None: + config["weights"] = checkpoint + config["classification_model"] = "resnet50" + error_message = "Trying to load resnet18 weights into a resnet50" + with pytest.raises(ValueError, match=error_message): + So2SatClassificationTask(**config) diff --git a/torchgeo/trainers/chesapeake.py b/torchgeo/trainers/chesapeake.py new file mode 100644 index 00000000000..15a2e247dee --- /dev/null +++ b/torchgeo/trainers/chesapeake.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Segmentation tasks.""" + +from typing import Any, Dict + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] + +from ..datasets import Chesapeake7 +from .segmentation import SemanticSegmentationTask + +# TODO: move the color maps to a dataset object +CMAP_7 = matplotlib.colors.ListedColormap( + [np.array(Chesapeake7.cmap[i]) / 255.0 for i in range(7)] +) +CMAP_5 = matplotlib.colors.ListedColormap( + np.array( + [ + (0, 0, 0, 0), + (0, 197, 255, 255), + (38, 115, 0, 255), + (163, 255, 115, 255), + (156, 156, 156, 255), + ] + ) + / 255.0 +) + + +# TODO: move this functionality into SemanticSegmentationTask and remove this class +class ChesapeakeCVPRSegmentationTask(SemanticSegmentationTask): + """LightningModule for training models on the Chesapeake CVPR Land Cover dataset. + + .. deprecated: 0.1 + Use :class:`SemanticSegmentationTask` instead. + """ + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step - reports average accuracy and average IoU. + + Logs the first 10 validation samples to tensorboard as images with 3 subplots + showing the image, mask, and predictions. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat_hard, y) + + if batch_idx < 10: + cmap = None + if self.hparams["num_classes"] == 5: + cmap = CMAP_5 + else: + cmap = CMAP_7 + # Render the image, ground truth mask, and predicted mask for the first + # image in the batch + img = np.rollaxis( # convert image to channels last format + batch["image"][0].cpu().numpy(), 0, 3 + ) + mask = batch["mask"][0].cpu().numpy() + pred = y_hat_hard[0].cpu().numpy() + fig, axs = plt.subplots(1, 3, figsize=(12, 4)) + axs[0].imshow(img[:, :, :3]) + axs[0].axis("off") + axs[1].imshow( + mask, + vmin=0, + vmax=self.hparams["num_classes"] - 1, + cmap=cmap, + interpolation="none", + ) + axs[1].axis("off") + axs[2].imshow( + pred, + vmin=0, + vmax=self.hparams["num_classes"] - 1, + cmap=cmap, + interpolation="none", + ) + axs[2].axis("off") + plt.tight_layout() + + # the SummaryWriter is a tensorboard object, see: + # https://pytorch.org/docs/stable/tensorboard.html# + summary_writer: SummaryWriter = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() diff --git a/torchgeo/trainers/landcoverai.py b/torchgeo/trainers/landcoverai.py new file mode 100644 index 00000000000..85d4fd81381 --- /dev/null +++ b/torchgeo/trainers/landcoverai.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Segmentation tasks.""" + +from typing import Any, Dict, cast + +import kornia.augmentation as K +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] + +from .segmentation import SemanticSegmentationTask + + +# TODO: move this functionality into SemanticSegmentationTask and remove this class +class LandCoverAISegmentationTask(SemanticSegmentationTask): + """LightningModule for training models on the Landcover.AI Dataset. + + .. deprecated: 0.1 + Use :class:`SemanticSegmentationTask` instead. + """ + + # TODO: move this to LandCoverAIDataModule + train_augmentations = K.AugmentationSequential( + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + data_keys=["input", "mask"], + ) + + def training_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> Tensor: + """Training step - reports average accuracy and average IoU. + + Args: + batch: Current batch + batch_idx: Index of current batch + + Returns: + training loss + """ + x = batch["image"] + y = batch["mask"] + with torch.no_grad(): + x, y = self.train_augmentations(x, y) + y = y.long().squeeze() + + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the train step logs every `log_every_n_steps` steps where + # `log_every_n_steps` is a parameter to the `Trainer` object + self.log("train_loss", loss, on_step=True, on_epoch=False) + self.train_metrics(y_hat_hard, y) + + return cast(Tensor, loss) + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step - reports average accuracy and average IoU. + + Logs the first 10 validation samples to tensorboard as images with 3 subplots + showing the image, mask, and predictions. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["mask"].long().squeeze() + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat_hard, y) + + if batch_idx < 10 and self.hparams["verbose"]: + # Render the image, ground truth mask, and predicted mask for the first + # image in the batch + img = np.rollaxis( # convert image to channels last format + x[0].cpu().numpy(), 0, 3 + ) + mask = y[0].cpu().numpy() + pred = y_hat_hard[0].cpu().numpy() + fig, axs = plt.subplots(1, 3, figsize=(12, 4)) + axs[0].imshow(img) + axs[0].axis("off") + axs[1].imshow(mask, vmin=0, vmax=5) + axs[1].axis("off") + axs[2].imshow(pred, vmin=0, vmax=5) + axs[2].axis("off") + + # the SummaryWriter is a tensorboard object, see: + # https://pytorch.org/docs/stable/tensorboard.html# + summary_writer: SummaryWriter = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + + plt.close() diff --git a/torchgeo/trainers/naipchesapeake.py b/torchgeo/trainers/naipchesapeake.py new file mode 100644 index 00000000000..b5480627772 --- /dev/null +++ b/torchgeo/trainers/naipchesapeake.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Segmentation tasks.""" + +from typing import Any, Dict + +import matplotlib.pyplot as plt +import numpy as np +from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] + +from .segmentation import SemanticSegmentationTask + + +# TODO: move this functionality into SemanticSegmentationTask and remove this class +class NAIPChesapeakeSegmentationTask(SemanticSegmentationTask): + """LightningModule for training models on the NAIP and Chesapeake datasets. + + .. deprecated: 0.1 + Use :class:`SemanticSegmentationTask` instead. + """ + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step - reports average accuracy and average IoU. + + Args: + batch: current batch + batch_idx: index of current batch + """ + x = batch["image"] + y = batch["mask"] + y_hat = self.forward(x) + y_hat_hard = y_hat.argmax(dim=1) + + loss = self.loss(y_hat, y) + + # by default, the test and validation steps only log per *epoch* + self.log("val_loss", loss) + self.val_metrics(y_hat_hard, y) + + if batch_idx < 10: + # Render the image, ground truth mask, and predicted mask for the first + # image in the batch + img = np.rollaxis( # convert image to channels last format + batch["image"][0].cpu().numpy(), 0, 3 + ) + mask = batch["mask"][0].cpu().numpy() + pred = y_hat_hard[0].cpu().numpy() + fig, axs = plt.subplots(1, 3, figsize=(12, 4)) + axs[0].imshow(img) + axs[0].axis("off") + axs[1].imshow(mask, vmin=0, vmax=4) + axs[1].axis("off") + axs[2].imshow(pred, vmin=0, vmax=4) + axs[2].axis("off") + + # the SummaryWriter is a tensorboard object, see: + # https://pytorch.org/docs/stable/tensorboard.html# + summary_writer: SummaryWriter = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + + plt.close() diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 5c52ff3ba7a..47e0d709ccd 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -5,10 +5,6 @@ from typing import Any, Dict, cast -import kornia.augmentation as K -import matplotlib -import matplotlib.pyplot as plt -import numpy as np import segmentation_models_pytorch as smp import torch import torch.nn as nn @@ -16,33 +12,14 @@ from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] from torchmetrics import Accuracy, IoU, MetricCollection -from ..datasets import Chesapeake7 from ..models import FCN # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 DataLoader.__module__ = "torch.utils.data" -# TODO: move the color maps to a dataset object -CMAP_7 = matplotlib.colors.ListedColormap( - [np.array(Chesapeake7.cmap[i]) / 255.0 for i in range(7)] -) -CMAP_5 = matplotlib.colors.ListedColormap( - np.array( - [ - (0, 0, 0, 0), - (0, 197, 255, 255), - (38, 115, 0, 255), - (163, 255, 115, 255), - (156, 156, 156, 255), - ] - ) - / 255.0 -) - class SemanticSegmentationTask(LightningModule): """LightningModule for semantic segmentation of images.""" @@ -244,225 +221,3 @@ def configure_optimizers(self) -> Dict[str, Any]: "monitor": "val_loss", }, } - - -# TODO: refactor any differences between these classes and SemanticSegmentationTask -# so that these classes are no longer needed. -class ChesapeakeCVPRSegmentationTask(SemanticSegmentationTask): - """LightningModule for training models on the Chesapeake CVPR Land Cover dataset. - - .. deprecated: 0.1 - Use :class:`SemanticSegmentationTask` instead. - """ - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Logs the first 10 validation samples to tensorboard as images with 3 subplots - showing the image, mask, and predictions. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat_hard, y) - - if batch_idx < 10: - cmap = None - if self.hparams["num_classes"] == 5: - cmap = CMAP_5 - else: - cmap = CMAP_7 - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - batch["image"][0].cpu().numpy(), 0, 3 - ) - mask = batch["mask"][0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img[:, :, :3]) - axs[0].axis("off") - axs[1].imshow( - mask, - vmin=0, - vmax=self.hparams["num_classes"] - 1, - cmap=cmap, - interpolation="none", - ) - axs[1].axis("off") - axs[2].imshow( - pred, - vmin=0, - vmax=self.hparams["num_classes"] - 1, - cmap=cmap, - interpolation="none", - ) - axs[2].axis("off") - plt.tight_layout() - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - - -class LandCoverAISegmentationTask(SemanticSegmentationTask): - """LightningModule for training models on the Landcover.AI Dataset. - - .. deprecated: 0.1 - Use :class:`SemanticSegmentationTask` instead. - """ - - # TODO: move this to LandCoverAIDataModule - train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["input", "mask"], - ) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step - reports average accuracy and average IoU. - - Args: - batch: Current batch - batch_idx: Index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["mask"] - with torch.no_grad(): - x, y = self.train_augmentations(x, y) - y = y.long().squeeze() - - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_metrics(y_hat_hard, y) - - return cast(Tensor, loss) - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Logs the first 10 validation samples to tensorboard as images with 3 subplots - showing the image, mask, and predictions. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"].long().squeeze() - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat_hard, y) - - if batch_idx < 10 and self.hparams["verbose"]: - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - x[0].cpu().numpy(), 0, 3 - ) - mask = y[0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img) - axs[0].axis("off") - axs[1].imshow(mask, vmin=0, vmax=5) - axs[1].axis("off") - axs[2].imshow(pred, vmin=0, vmax=5) - axs[2].axis("off") - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - - plt.close() - - -class NAIPChesapeakeSegmentationTask(SemanticSegmentationTask): - """LightningModule for training models on the NAIP and Chesapeake datasets. - - .. deprecated: 0.1 - Use :class:`SemanticSegmentationTask` instead. - """ - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Args: - batch: current batch - batch_idx: index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("val_loss", loss) - self.val_metrics(y_hat_hard, y) - - if batch_idx < 10: - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - batch["image"][0].cpu().numpy(), 0, 3 - ) - mask = batch["mask"][0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img) - axs[0].axis("off") - axs[1].imshow(mask, vmin=0, vmax=4) - axs[1].axis("off") - axs[2].imshow(pred, vmin=0, vmax=4) - axs[2].axis("off") - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - - plt.close() diff --git a/torchgeo/trainers/so2sat.py b/torchgeo/trainers/so2sat.py new file mode 100644 index 00000000000..3bda62acb9f --- /dev/null +++ b/torchgeo/trainers/so2sat.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Classification tasks.""" + +import os + +import torch +import torch.nn as nn +import torchvision.models +from torch.nn.modules import Conv2d, Linear + +from . import utils +from .classification import ClassificationTask + + +# TODO: move this functionality into ClassificationTask and remove this class +class So2SatClassificationTask(ClassificationTask): + """LightningModule for training models on the So2Sat Dataset. + + .. deprecated:: 0.1 + Use :class:`ClassificationTask` instead. + """ + + def config_model(self) -> None: + """Configures the model based on kwargs parameters passed to the constructor.""" + in_channels = self.hparams["in_channels"] + + pretrained = False + if not os.path.exists(self.hparams["weights"]): + if self.hparams["weights"] == "imagenet": + pretrained = True + elif self.hparams["weights"] == "random": + pretrained = False + else: + raise ValueError( + f"Weight type '{self.hparams['weights']}' is not valid." + ) + + # Create the model + if "resnet" in self.hparams["classification_model"]: + self.model = getattr( + torchvision.models.resnet, self.hparams["classification_model"] + )(pretrained=pretrained) + in_features = self.model.fc.in_features + self.model.fc = Linear( + in_features, out_features=self.hparams["num_classes"] + ) + + # Update first layer + if in_channels != 3: + w_old = None + if pretrained: + w_old = torch.clone( # type: ignore[attr-defined] + self.model.conv1.weight + ).detach() + # Create the new layer + self.model.conv1 = Conv2d( + in_channels, 64, kernel_size=7, stride=1, padding=2, bias=False + ) + nn.init.kaiming_normal_( # type: ignore[no-untyped-call] + self.model.conv1.weight, mode="fan_out", nonlinearity="relu" + ) + + # We copy over the pretrained RGB weights + if pretrained: + w_new = torch.clone( # type: ignore[attr-defined] + self.model.conv1.weight + ).detach() + w_new[:, :3, :, :] = w_old + self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 + w_new + ) + else: + raise ValueError( + f"Model type '{self.hparams['classification_model']}' is not valid." + ) + + # Load pretrained weights checkpoint weights + if "resnet" in self.hparams["classification_model"]: + if os.path.exists(self.hparams["weights"]): + name, state_dict = utils.extract_encoder(self.hparams["weights"]) + + if self.hparams["classification_model"] != name: + raise ValueError( + f"Trying to load {name} weights into a " + f"{self.hparams['classification_model']}" + ) + + self.model = utils.load_state_dict(self.model, state_dict) diff --git a/train.py b/train.py index 668be45852f..d776bcf85b0 100755 --- a/train.py +++ b/train.py @@ -32,12 +32,10 @@ RegressionTask, SemanticSegmentationTask, ) -from torchgeo.trainers.classification import So2SatClassificationTask -from torchgeo.trainers.segmentation import ( - ChesapeakeCVPRSegmentationTask, - LandCoverAISegmentationTask, - NAIPChesapeakeSegmentationTask, -) +from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask +from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask +from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask +from torchgeo.trainers.so2sat import So2SatClassificationTask TASK_TO_MODULES_MAPPING: Dict[ str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]] From 45f2feaf1284dc517c965018179fc62170d044ce Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 6 Nov 2021 18:46:20 -0500 Subject: [PATCH 5/5] Remove duplicate So2Sat trainer --- torchgeo/trainers/classification.py | 77 ----------------------------- 1 file changed, 77 deletions(-) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 820754c7dbf..c0b447b7418 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -10,7 +10,6 @@ import timm import torch import torch.nn as nn -import torchvision.models from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss from torch import Tensor from torch.nn.modules import Conv2d, Linear @@ -354,79 +353,3 @@ def test_step( # type: ignore[override] # by default, the test and validation steps only log per *epoch* self.log("test_loss", loss, on_step=False, on_epoch=True) self.test_metrics(y_hat_hard, y) - - -# TODO: move this functionality into ClassificationTask and remove this class -class So2SatClassificationTask(ClassificationTask): - """LightningModule for training models on the So2Sat Dataset. - - .. deprecated:: 0.1 - Use :class:`ClassificationTask` instead. - """ - - def config_model(self) -> None: - """Configures the model based on kwargs parameters passed to the constructor.""" - in_channels = self.hparams["in_channels"] - - pretrained = False - if not os.path.exists(self.hparams["weights"]): - if self.hparams["weights"] == "imagenet": - pretrained = True - elif self.hparams["weights"] == "random": - pretrained = False - else: - raise ValueError( - f"Weight type '{self.hparams['weights']}' is not valid." - ) - - # Create the model - if "resnet" in self.hparams["classification_model"]: - self.model = getattr( - torchvision.models.resnet, self.hparams["classification_model"] - )(pretrained=pretrained) - in_features = self.model.fc.in_features - self.model.fc = Linear( - in_features, out_features=self.hparams["num_classes"] - ) - - # Update first layer - if in_channels != 3: - w_old = None - if pretrained: - w_old = torch.clone( # type: ignore[attr-defined] - self.model.conv1.weight - ).detach() - # Create the new layer - self.model.conv1 = Conv2d( - in_channels, 64, kernel_size=7, stride=1, padding=2, bias=False - ) - nn.init.kaiming_normal_( # type: ignore[no-untyped-call] - self.model.conv1.weight, mode="fan_out", nonlinearity="relu" - ) - - # We copy over the pretrained RGB weights - if pretrained: - w_new = torch.clone( # type: ignore[attr-defined] - self.model.conv1.weight - ).detach() - w_new[:, :3, :, :] = w_old - self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 - w_new - ) - else: - raise ValueError( - f"Model type '{self.hparams['classification_model']}' is not valid." - ) - - # Load pretrained weights checkpoint weights - if "resnet" in self.hparams["classification_model"]: - if os.path.exists(self.hparams["weights"]): - name, state_dict = utils.extract_encoder(self.hparams["weights"]) - - if self.hparams["classification_model"] != name: - raise ValueError( - f"Trying to load {name} weights into a " - f"{self.hparams['classification_model']}" - ) - - self.model = utils.load_state_dict(self.model, state_dict)