From 42b9a6dbd218de0d475fe1d2d1f26f297c03067b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 1 Jan 2022 14:14:19 -0600 Subject: [PATCH] Remove dataset-specific trainers (#286) * Remove dataset-specific trainers * Collation functions will be new in 0.2.0 * Clarify arg docstring * Style fixes * Remove files forgotten in rebase * Fix bug in unbind_samples, add tests * Fix bugs in datamodule augmentations * Increase coverage for datamodules * Fix bugs in logger plotting, properly test * Fix tests * Increase coverage of trainers * Use datamodule plot instead of dataset plot * Skip datamodules without tests * Plot predictions * Fix ClassificationTask tests * Fix SemanticSegmentationTask tests * EAFP -> LBYL * Ensure that tensors are on the CPU before plotting --- conf/task_defaults/eurosat.yaml | 2 +- conf/task_defaults/resisc45.yaml | 2 +- conf/task_defaults/ucmerced.yaml | 2 +- docs/api/datasets.rst | 1 + tests/datasets/test_utils.py | 22 ++++- tests/trainers/test_chesapeake.py | 64 ------------- tests/trainers/test_classification.py | 36 +++++++ tests/trainers/test_landcoverai.py | 67 ------------- tests/trainers/test_naipchesapeake.py | 55 ----------- tests/trainers/test_regression.py | 17 ++++ tests/trainers/test_resisc45.py | 54 ----------- tests/trainers/test_segmentation.py | 17 ++++ tests/trainers/test_utils.py | 21 ----- torchgeo/datamodules/bigearthnet.py | 5 + torchgeo/datamodules/cowc.py | 5 + torchgeo/datamodules/etci2021.py | 5 + torchgeo/datamodules/eurosat.py | 5 + torchgeo/datamodules/landcoverai.py | 47 ++++++++- torchgeo/datamodules/resisc45.py | 42 +++++++++ torchgeo/datamodules/ucmerced.py | 5 + torchgeo/datasets/__init__.py | 9 +- torchgeo/datasets/cowc.py | 2 +- torchgeo/datasets/utils.py | 41 ++++++++ torchgeo/trainers/chesapeake.py | 104 -------------------- torchgeo/trainers/classification.py | 31 ++++++ torchgeo/trainers/landcoverai.py | 131 -------------------------- torchgeo/trainers/naipchesapeake.py | 66 ------------- torchgeo/trainers/regression.py | 17 ++++ torchgeo/trainers/resisc45.py | 70 -------------- torchgeo/trainers/segmentation.py | 16 ++++ train.py | 16 ++-- 31 files changed, 327 insertions(+), 650 deletions(-) delete mode 100644 tests/trainers/test_chesapeake.py delete mode 100644 tests/trainers/test_landcoverai.py delete mode 100644 tests/trainers/test_naipchesapeake.py delete mode 100644 tests/trainers/test_resisc45.py delete mode 100644 torchgeo/trainers/chesapeake.py delete mode 100644 torchgeo/trainers/landcoverai.py delete mode 100644 torchgeo/trainers/naipchesapeake.py delete mode 100644 torchgeo/trainers/resisc45.py diff --git a/conf/task_defaults/eurosat.yaml b/conf/task_defaults/eurosat.yaml index 9161c2c0cae..2f288b56bdd 100644 --- a/conf/task_defaults/eurosat.yaml +++ b/conf/task_defaults/eurosat.yaml @@ -7,7 +7,7 @@ experiment: learning_rate_schedule_patience: 6 weights: "random" in_channels: 13 - num_classes: 10 + num_classes: 2 datamodule: root_dir: "tests/data/eurosat" batch_size: 1 diff --git a/conf/task_defaults/resisc45.yaml b/conf/task_defaults/resisc45.yaml index e95efe9af89..8e46b26ecca 100644 --- a/conf/task_defaults/resisc45.yaml +++ b/conf/task_defaults/resisc45.yaml @@ -7,7 +7,7 @@ experiment: learning_rate_schedule_patience: 6 weights: "random" in_channels: 3 - num_classes: 45 + num_classes: 3 datamodule: root_dir: "tests/data/resisc45" batch_size: 1 diff --git a/conf/task_defaults/ucmerced.yaml b/conf/task_defaults/ucmerced.yaml index 31f7dba2960..04de488a7e3 100644 --- a/conf/task_defaults/ucmerced.yaml +++ b/conf/task_defaults/ucmerced.yaml @@ -7,7 +7,7 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 - num_classes: 21 + num_classes: 2 datamodule: root_dir: "tests/data/ucmerced" batch_size: 1 diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index f3df706a8b4..6979999de03 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -270,3 +270,4 @@ Collation Functions .. autofunction:: stack_samples .. autofunction:: concat_samples .. autofunction:: merge_samples +.. autofunction:: unbind_samples diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 8803eb3d3bb..a935ccea563 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -31,6 +31,7 @@ merge_samples, percentile_normalization, stack_samples, + unbind_samples, working_dir, ) @@ -457,7 +458,7 @@ def samples(self) -> List[Dict[str, Any]]: }, ] - def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None: + def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None: sample = stack_samples(samples) assert sample["image"].size() == torch.Size( # type: ignore[attr-defined] [2, 3] @@ -468,6 +469,13 @@ def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None: ) assert sample["crs"] == [CRS.from_epsg(2000), CRS.from_epsg(2001)] + new_samples = unbind_samples(sample) + for i in range(2): + assert torch.allclose( # type: ignore[attr-defined] + samples[i]["image"], new_samples[i]["image"] + ) + assert samples[i]["crs"] == new_samples[i]["crs"] + def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None: sample = concat_samples(samples) assert sample["image"].size() == torch.Size([6]) # type: ignore[attr-defined] @@ -500,7 +508,7 @@ def samples(self) -> List[Dict[str, Any]]: }, ] - def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None: + def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None: sample = stack_samples(samples) assert sample["image"].size() == torch.Size( # type: ignore[attr-defined] [1, 3] @@ -515,6 +523,16 @@ def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None: assert sample["crs1"] == [CRS.from_epsg(2000)] assert sample["crs2"] == [CRS.from_epsg(2001)] + new_samples = unbind_samples(sample) + assert torch.allclose( # type: ignore[attr-defined] + samples[0]["image"], new_samples[0]["image"] + ) + assert samples[0]["crs1"] == new_samples[0]["crs1"] + assert torch.allclose( # type: ignore[attr-defined] + samples[1]["mask"], new_samples[0]["mask"] + ) + assert samples[1]["crs2"] == new_samples[0]["crs2"] + def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None: sample = concat_samples(samples) assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined] diff --git a/tests/trainers/test_chesapeake.py b/tests/trainers/test_chesapeake.py deleted file mode 100644 index 377d4a85e35..00000000000 --- a/tests/trainers/test_chesapeake.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator, cast - -import pytest -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch -from omegaconf import OmegaConf - -from torchgeo.datamodules import ChesapeakeCVPRDataModule -from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask - -from .test_utils import FakeTrainer, mocked_log - - -class TestChesapeakeCVPRSegmentationTask: - @pytest.fixture(scope="class", params=[5, 7]) - def class_set(self, request: SubRequest) -> int: - return cast(int, request.param) - - @pytest.fixture(scope="class") - def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: - dm = ChesapeakeCVPRDataModule( - os.path.join("tests", "data", "chesapeake", "cvpr"), - ["de-test"], - ["de-test"], - ["de-test"], - patch_size=32, - patches_per_tile=2, - batch_size=2, - num_workers=0, - class_set=class_set, - ) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture - def config(self, class_set: int) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", f"chesapeake_cvpr_{class_set}.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> ChesapeakeCVPRSegmentationTask: - task = ChesapeakeCVPRSegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_validation( - self, datamodule: ChesapeakeCVPRDataModule, task: ChesapeakeCVPRSegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index f2c091c75ee..4194982da6c 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -50,6 +50,23 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer.fit(model=model, datamodule=datamodule) trainer.test(model=model, datamodule=datamodule) + def test_no_logger(self) -> None: + conf = OmegaConf.load(os.path.join("conf", "task_defaults", "ucmerced.yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = UCMercedDataModule(**datamodule_kwargs) + + # Instantiate model + model_kwargs = conf_dict["module"] + model = ClassificationTask(**model_kwargs) + + # Instantiate trainer + trainer = Trainer(logger=None, fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: return { @@ -120,6 +137,25 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer.fit(model=model, datamodule=datamodule) trainer.test(model=model, datamodule=datamodule) + def test_no_logger(self) -> None: + conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "bigearthnet_s1.yaml") + ) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = BigEarthNetDataModule(**datamodule_kwargs) + + # Instantiate model + model_kwargs = conf_dict["module"] + model = MultiLabelClassificationTask(**model_kwargs) + + # Instantiate trainer + trainer = Trainer(logger=None, fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: return { diff --git a/tests/trainers/test_landcoverai.py b/tests/trainers/test_landcoverai.py deleted file mode 100644 index d3e70dfb098..00000000000 --- a/tests/trainers/test_landcoverai.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator, cast - -import pytest -from _pytest.monkeypatch import MonkeyPatch -from omegaconf import OmegaConf - -from torchgeo.datamodules import LandCoverAIDataModule -from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask - -from .test_utils import FakeTrainer, mocked_log - - -class TestLandCoverAISegmentationTask: - @pytest.fixture(scope="class") - def datamodule(self) -> LandCoverAIDataModule: - root = os.path.join("tests", "data", "landcoverai") - batch_size = 2 - num_workers = 0 - dm = LandCoverAIDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture - def config(self) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "landcoverai.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["verbose"] = True - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> LandCoverAISegmentationTask: - task = LandCoverAISegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_training( - self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test( - self, datamodule: LandCoverAIDataModule, task: LandCoverAISegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) diff --git a/tests/trainers/test_naipchesapeake.py b/tests/trainers/test_naipchesapeake.py deleted file mode 100644 index 37d94cb0ed8..00000000000 --- a/tests/trainers/test_naipchesapeake.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator, cast - -import pytest -from _pytest.monkeypatch import MonkeyPatch -from omegaconf import OmegaConf - -from torchgeo.datamodules import NAIPChesapeakeDataModule -from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask - -from .test_utils import FakeTrainer, mocked_log - - -class TestNAIPChesapeakeSegmentationTask: - @pytest.fixture(scope="class") - def datamodule(self) -> NAIPChesapeakeDataModule: - dm = NAIPChesapeakeDataModule( - os.path.join("tests", "data", "naip"), - os.path.join("tests", "data", "chesapeake", "BAYWIDE"), - batch_size=2, - num_workers=0, - ) - dm.patch_size = 32 - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture - def config(self) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "naipchesapeake.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> NAIPChesapeakeSegmentationTask: - task = NAIPChesapeakeSegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_validation( - self, datamodule: NAIPChesapeakeDataModule, task: NAIPChesapeakeSegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 574c930e7d6..3b3db221257 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -35,6 +35,23 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer.fit(model=model, datamodule=datamodule) trainer.test(model=model, datamodule=datamodule) + def test_no_logger(self) -> None: + conf = OmegaConf.load(os.path.join("conf", "task_defaults", "cyclone.yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = CycloneDataModule(**datamodule_kwargs) + + # Instantiate model + model_kwargs = conf_dict["module"] + model = RegressionTask(**model_kwargs) + + # Instantiate trainer + trainer = Trainer(logger=None, fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + def test_invalid_model(self) -> None: match = "Model type 'invalid_model' is not valid." with pytest.raises(ValueError, match=match): diff --git a/tests/trainers/test_resisc45.py b/tests/trainers/test_resisc45.py deleted file mode 100644 index 1eec36e2fee..00000000000 --- a/tests/trainers/test_resisc45.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from typing import Any, Dict, Generator - -import pytest -from _pytest.monkeypatch import MonkeyPatch - -from torchgeo.datamodules import RESISC45DataModule -from torchgeo.trainers.resisc45 import RESISC45ClassificationTask - -from .test_utils import FakeTrainer, mocked_log - - -class TestRESISC45ClassificationTask: - @pytest.fixture(scope="class") - def datamodule(self) -> RESISC45DataModule: - root = os.path.join("tests", "data", "resisc45") - batch_size = 2 - num_workers = 0 - dm = RESISC45DataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture() - def config(self) -> Dict[str, Any]: - task_args: Dict[str, Any] = {} - task_args["classification_model"] = "resnet18" - task_args["learning_rate"] = 3e-4 - task_args["learning_rate_schedule_patience"] = 6 - task_args["in_channels"] = 3 - task_args["loss"] = "ce" - task_args["num_classes"] = 45 - task_args["weights"] = "random" - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> RESISC45ClassificationTask: - task = RESISC45ClassificationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_training( - self, datamodule: RESISC45DataModule, task: RESISC45ClassificationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 1d13299334e..e29f0d99536 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -53,6 +53,23 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer.fit(model=model, datamodule=datamodule) trainer.test(model=model, datamodule=datamodule) + def test_no_logger(self) -> None: + conf = OmegaConf.load(os.path.join("conf", "task_defaults", "landcoverai.yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = LandCoverAIDataModule(**datamodule_kwargs) + + # Instantiate model + model_kwargs = conf_dict["module"] + model = SemanticSegmentationTask(**model_kwargs) + + # Instantiate trainer + trainer = Trainer(logger=None, fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: return { diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index d6daddc596c..de6d07f054b 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -3,7 +3,6 @@ import os from pathlib import Path -from typing import Any import pytest import torch @@ -17,26 +16,6 @@ ) -class FakeExperiment(object): - def add_figure(self, *args: Any, **kwargs: Any) -> None: - pass - - -class FakeLogger(object): - def __init__(self) -> None: - self.experiment = FakeExperiment() - - -class FakeTrainer(object): - def __init__(self) -> None: - self.logger = FakeLogger() - self.global_step = 1 - - -def mocked_log(*args: Any, **kwargs: Any) -> None: - pass - - def test_extract_encoder_unsupported_model(tmp_path: Path) -> None: checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}} path = os.path.join(str(tmp_path), "dummy.ckpt") diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index 11c2e4ed9ab..890e94fc2cc 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch.utils.data import DataLoader @@ -176,3 +177,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.BigEarthNet.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 4d6e4a7cdb8..a4c6f98e810 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch import Generator # type: ignore[attr-defined] from torch.utils.data import DataLoader, random_split @@ -121,3 +122,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.COWC.plot`.""" + return self.val_dataset.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 5db89a07379..933433f0ceb 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch import Generator # type: ignore[attr-defined] @@ -149,3 +150,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.ETCI2021.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index 72708e07019..8a4281eccc8 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch.utils.data import DataLoader @@ -146,3 +147,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.EuroSAT.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index b0f23182a27..8feb83bb7b3 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -5,6 +5,8 @@ from typing import Any, Dict, Optional +import kornia.augmentation as K +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader @@ -36,6 +38,45 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers + def on_after_batch_transfer( + self, batch: Dict[str, Any], batch_idx: int + ) -> Dict[str, Any]: + """Apply batch augmentations after batch is transferred to the device. + + Args: + batch: mini-batch of data + batch_idx: batch index + + Returns: + augmented mini-batch + """ + if ( + hasattr(self, "trainer") + and hasattr(self.trainer, "training") + and self.trainer.training # type: ignore[union-attr] + ): + # Kornia expects masks to be floats with a channel dimension + x = batch["image"] + y = batch["mask"].float().unsqueeze(1) + + train_augmentations = K.AugmentationSequential( + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter( + p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1 + ), + data_keys=["input", "mask"], + ) + x, y = train_augmentations(x, y) + + # torchmetrics expects masks to be longs without a channel dimension + batch["image"] = x + batch["mask"] = y.squeeze(1).long() + + return batch + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. @@ -57,7 +98,7 @@ def prepare_data(self) -> None: This method is only called once per run. """ - _ = LandCoverAI(self.root_dir, download=False, checksum=False) + LandCoverAI(self.root_dir, download=False, checksum=False) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. @@ -120,3 +161,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.LandCoverAI.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 844ee0968a9..f3e076399fa 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -5,6 +5,8 @@ from typing import Any, Dict, Optional +import kornia.augmentation as K +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch.utils.data import DataLoader @@ -48,6 +50,42 @@ def __init__( self.norm = Normalize(self.band_means, self.band_stds) + def on_after_batch_transfer( + self, batch: Dict[str, Any], batch_idx: int + ) -> Dict[str, Any]: + """Apply batch augmentations after batch is transferred to the device. + + Args: + batch: mini-batch of data + batch_idx: batch index + + Returns: + augmented mini-batch + """ + if ( + hasattr(self, "trainer") + and hasattr(self.trainer, "training") + and self.trainer.training # type: ignore[union-attr] + ): + x = batch["image"] + + train_augmentations = K.AugmentationSequential( + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.RandomErasing(p=0.1), + K.ColorJitter( + p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1 + ), + data_keys=["input"], + ) + x = train_augmentations(x) + + batch["image"] = x + + return batch + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. @@ -121,3 +159,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.RESISC45.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 69cd9773384..77dc718f7b2 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch import torchvision @@ -123,3 +124,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.UCMerced.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 7e3cf7811c2..0571ea7651c 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -67,7 +67,13 @@ from .so2sat import So2Sat from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 from .ucmerced import UCMerced -from .utils import BoundingBox, concat_samples, merge_samples, stack_samples +from .utils import ( + BoundingBox, + concat_samples, + merge_samples, + stack_samples, + unbind_samples, +) from .vaihingen import Vaihingen2D from .xview import XView2 from .zuericrop import ZueriCrop @@ -150,6 +156,7 @@ "concat_samples", "merge_samples", "stack_samples", + "unbind_samples", ) # https://stackoverflow.com/questions/40018681 diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 45ef7da2c17..9e7d28d7dbe 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -200,7 +200,7 @@ def plot( """Plot a sample from the dataset. Args: - sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` + sample: a sample returned by :meth:`__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional string to use as a suptitle diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index b786e9b996a..7a5349c6a37 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -46,6 +46,7 @@ "stack_samples", "concat_samples", "merge_samples", + "unbind_samples", "rasterio_loader", "sort_sentinel2_bands", "draw_semantic_segmentation_masks", @@ -444,6 +445,26 @@ def _list_dict_to_dict_list(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, List return collated +def _dict_list_to_list_dict(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]: + """Convert a dictionary of lists to a list of dictionaries. + + Args: + sample: a dictionary of lists + + Returns: + a list of dictionaries + + .. versionadded:: 0.2 + """ + uncollated: List[Dict[Any, Any]] = [ + {} for _ in range(max(map(len, sample.values()))) + ] + for key, values in sample.items(): + for i, value in enumerate(values): + uncollated[i][key] = value + return uncollated + + def stack_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: """Stack a list of samples along a new axis. @@ -514,6 +535,26 @@ def merge_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: return collated +def unbind_samples(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]: + """Reverse of :func:`stack_samples`. + + Useful for turning a mini-batch of samples into a list of samples. These individual + samples can then be plotted using a dataset's ``plot`` method. + + Args: + sample: a mini-batch of samples + + Returns: + list of samples + + .. versionadded:: 0.2 + """ + for key, values in sample.items(): + if isinstance(values, Tensor): + sample[key] = torch.unbind(values) + return _dict_list_to_list_dict(sample) + + def rasterio_loader(path: str) -> "np.typing.NDArray[np.int_]": """Load an image file using rasterio. diff --git a/torchgeo/trainers/chesapeake.py b/torchgeo/trainers/chesapeake.py deleted file mode 100644 index f1c78ef872a..00000000000 --- a/torchgeo/trainers/chesapeake.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Segmentation tasks.""" - -from typing import Any, Dict - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] - -from ..datasets import Chesapeake7 -from .segmentation import SemanticSegmentationTask - -# TODO: move the color maps to a dataset object -CMAP_7 = matplotlib.colors.ListedColormap( - [np.array(Chesapeake7.cmap[i]) / 255.0 for i in range(7)] -) -CMAP_5 = matplotlib.colors.ListedColormap( - np.array( - [ - (0, 0, 0, 0), - (0, 197, 255, 255), - (38, 115, 0, 255), - (163, 255, 115, 255), - (156, 156, 156, 255), - ] - ) - / 255.0 -) - - -# TODO: move this functionality into SemanticSegmentationTask and remove this class -class ChesapeakeCVPRSegmentationTask(SemanticSegmentationTask): - """LightningModule for training models on the Chesapeake CVPR Land Cover dataset. - - .. deprecated:: 0.1 - Use :class:`SemanticSegmentationTask` instead. - """ - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Logs the first 10 validation samples to tensorboard as images with 3 subplots - showing the image, mask, and predictions. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat_hard, y) - - if batch_idx < 10: - cmap = None - if self.hparams["num_classes"] == 5: - cmap = CMAP_5 - else: - cmap = CMAP_7 - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - batch["image"][0].cpu().numpy(), 0, 3 - ) - mask = batch["mask"][0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img[:, :, :3]) - axs[0].axis("off") - axs[1].imshow( - mask, - vmin=0, - vmax=self.hparams["num_classes"] - 1, - cmap=cmap, - interpolation="none", - ) - axs[1].axis("off") - axs[2].imshow( - pred, - vmin=0, - vmax=self.hparams["num_classes"] - 1, - cmap=cmap, - interpolation="none", - ) - axs[2].axis("off") - plt.tight_layout() - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index cf8d3e290b7..5b69f10b14d 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -16,6 +16,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import Accuracy, FBeta, IoU, MetricCollection +from ..datasets.utils import unbind_samples from . import utils # https://github.com/pytorch/pytorch/issues/60979 @@ -179,6 +180,21 @@ def validation_step( # type: ignore[override] self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) + if batch_idx < 10: + try: + datamodule = self.trainer.datamodule # type: ignore[attr-defined] + batch["prediction"] = y_hat_hard + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + except AttributeError: + pass + def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. @@ -332,6 +348,21 @@ def validation_step( # type: ignore[override] self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) + if batch_idx < 10: + try: + datamodule = self.trainer.datamodule # type: ignore[attr-defined] + batch["prediction"] = y_hat_hard + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + except AttributeError: + pass + def test_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> None: diff --git a/torchgeo/trainers/landcoverai.py b/torchgeo/trainers/landcoverai.py deleted file mode 100644 index 7c9ac27856a..00000000000 --- a/torchgeo/trainers/landcoverai.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Segmentation tasks.""" - -from typing import Any, Dict, cast - -import kornia.augmentation as K -import matplotlib.pyplot as plt -import numpy as np -import torch -from torch import Tensor -from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] - -from .segmentation import SemanticSegmentationTask - - -# TODO: move this functionality into SemanticSegmentationTask and remove this class -class LandCoverAISegmentationTask(SemanticSegmentationTask): - """LightningModule for training models on the Landcover.AI Dataset. - - .. deprecated:: 0.1 - Use :class:`SemanticSegmentationTask` instead. - """ - - # TODO: move this to LandCoverAIDataModule - train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["input", "mask"], - ) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step - reports average accuracy and average IoU. - - Args: - batch: Current batch - batch_idx: Index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["mask"].float().unsqueeze(1) - with torch.no_grad(): - x, y = self.train_augmentations(x, y) - y = y.squeeze(1).long() - - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_metrics(y_hat_hard, y) - - return cast(Tensor, loss) - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Logs the first 10 validation samples to tensorboard as images with 3 subplots - showing the image, mask, and predictions. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat_hard, y) - - if batch_idx < 10 and self.hparams["verbose"]: - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - x[0].cpu().numpy(), 0, 3 - ) - mask = y[0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img) - axs[0].axis("off") - axs[1].imshow(mask, vmin=0, vmax=5) - axs[1].axis("off") - axs[2].imshow(pred, vmin=0, vmax=5) - axs[2].axis("off") - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - - plt.close() - - def test_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Test step identical to the validation step. - - Args: - batch: Current batch - batch_idx: Index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss, on_step=False, on_epoch=True) - self.test_metrics(y_hat_hard, y) diff --git a/torchgeo/trainers/naipchesapeake.py b/torchgeo/trainers/naipchesapeake.py deleted file mode 100644 index 4139b00ce67..00000000000 --- a/torchgeo/trainers/naipchesapeake.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Segmentation tasks.""" - -from typing import Any, Dict - -import matplotlib.pyplot as plt -import numpy as np -from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] - -from .segmentation import SemanticSegmentationTask - - -# TODO: move this functionality into SemanticSegmentationTask and remove this class -class NAIPChesapeakeSegmentationTask(SemanticSegmentationTask): - """LightningModule for training models on the NAIP and Chesapeake datasets. - - .. deprecated:: 0.1 - Use :class:`SemanticSegmentationTask` instead. - """ - - def validation_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> None: - """Validation step - reports average accuracy and average IoU. - - Args: - batch: current batch - batch_idx: index of current batch - """ - x = batch["image"] - y = batch["mask"] - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("val_loss", loss) - self.val_metrics(y_hat_hard, y) - - if batch_idx < 10: - # Render the image, ground truth mask, and predicted mask for the first - # image in the batch - img = np.rollaxis( # convert image to channels last format - batch["image"][0].cpu().numpy(), 0, 3 - ) - mask = batch["mask"][0].cpu().numpy() - pred = y_hat_hard[0].cpu().numpy() - fig, axs = plt.subplots(1, 3, figsize=(12, 4)) - axs[0].imshow(img) - axs[0].axis("off") - axs[1].imshow(mask, vmin=0, vmax=4) - axs[1].axis("off") - axs[2].imshow(pred, vmin=0, vmax=4) - axs[2].axis("off") - - # the SummaryWriter is a tensorboard object, see: - # https://pytorch.org/docs/stable/tensorboard.html# - summary_writer: SummaryWriter = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - - plt.close() diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 8e6422d02c9..6393ceda2ff 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -15,6 +15,8 @@ from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection from torchvision import models +from ..datasets.utils import unbind_samples + # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 Conv2d.__module__ = "nn.Conv2d" @@ -107,6 +109,21 @@ def validation_step( # type: ignore[override] self.log("val_loss", loss) self.val_metrics(y_hat, y) + if batch_idx < 10: + try: + datamodule = self.trainer.datamodule # type: ignore[attr-defined] + batch["prediction"] = y_hat + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + except AttributeError: + pass + def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. diff --git a/torchgeo/trainers/resisc45.py b/torchgeo/trainers/resisc45.py deleted file mode 100644 index 27c1ec32d93..00000000000 --- a/torchgeo/trainers/resisc45.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Custom trainer for the RESISC45 dataset.""" - -from typing import Any, Dict, cast - -import kornia.augmentation as K -import torch -from torch import Tensor - -from .classification import ClassificationTask - - -# TODO: move this functionality into ClassificationTask and remove this class -class RESISC45ClassificationTask(ClassificationTask): - """LightningModule for training on RESISC45 with data augmentation. - - .. deprecated:: 0.1 - Use :class:`ClassificationTask` instead. - """ - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - classification_model: Name of the classification model use - loss: Name of the loss function - weights: Either "random", "imagenet_only", "imagenet_and_random", or - "random_rgb" - """ - super().__init__(**kwargs) - - self.train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.RandomErasing(p=0.1), - K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["input"], - ) - - def training_step( # type: ignore[override] - self, batch: Dict[str, Any], batch_idx: int - ) -> Tensor: - """Training step - reports average accuracy and average IoU. - - Args: - batch: Current batch - batch_idx: Index of current batch - - Returns: - training loss - """ - x = batch["image"] - y = batch["label"] - with torch.no_grad(): - x = self.train_augmentations(x) - y_hat = self.forward(x) - y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_metrics(y_hat_hard, y) - - return cast(Tensor, loss) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index ccb7c6d6b28..1da0cf05d76 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -14,6 +14,7 @@ from torch.utils.data import DataLoader from torchmetrics import Accuracy, IoU, MetricCollection +from ..datasets.utils import unbind_samples from ..models import FCN # https://github.com/pytorch/pytorch/issues/60979 @@ -173,6 +174,21 @@ def validation_step( # type: ignore[override] self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) + if batch_idx < 10: + try: + datamodule = self.trainer.datamodule # type: ignore[attr-defined] + batch["prediction"] = y_hat_hard + for key in ["image", "mask", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + except AttributeError: + pass + def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. diff --git a/train.py b/train.py index eac21a07d7f..057607a56b2 100755 --- a/train.py +++ b/train.py @@ -35,10 +35,6 @@ RegressionTask, SemanticSegmentationTask, ) -from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask -from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask -from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask -from torchgeo.trainers.resisc45 import RESISC45ClassificationTask TASK_TO_MODULES_MAPPING: Dict[ str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]] @@ -47,18 +43,18 @@ "bigearthnet_s1": (MultiLabelClassificationTask, BigEarthNetDataModule), "bigearthnet_s2": (MultiLabelClassificationTask, BigEarthNetDataModule), "byol": (BYOLTask, ChesapeakeCVPRDataModule), - "chesapeake_cvpr_5": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), - "chesapeake_cvpr_7": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), - "chesapeake_cvpr_prior": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), + "chesapeake_cvpr_5": (SemanticSegmentationTask, ChesapeakeCVPRDataModule), + "chesapeake_cvpr_7": (SemanticSegmentationTask, ChesapeakeCVPRDataModule), + "chesapeake_cvpr_prior": (SemanticSegmentationTask, ChesapeakeCVPRDataModule), "cowc_counting": (RegressionTask, COWCCountingDataModule), "cyclone": (RegressionTask, CycloneDataModule), "eurosat": (ClassificationTask, EuroSATDataModule), "etci2021": (SemanticSegmentationTask, ETCI2021DataModule), - "landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule), - "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule), + "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule), + "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule), "oscd_all": (SemanticSegmentationTask, OSCDDataModule), "oscd_rgb": (SemanticSegmentationTask, OSCDDataModule), - "resisc45": (RESISC45ClassificationTask, RESISC45DataModule), + "resisc45": (ClassificationTask, RESISC45DataModule), "sen12ms_all": (SemanticSegmentationTask, SEN12MSDataModule), "sen12ms_s1": (SemanticSegmentationTask, SEN12MSDataModule), "sen12ms_s2_all": (SemanticSegmentationTask, SEN12MSDataModule),