From 6bf2b19a8536a3e9660ecf432c2c2a4e97812334 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 31 Dec 2021 14:41:59 -0600 Subject: [PATCH] Use datamodule plot instead of dataset plot --- torchgeo/datamodules/bigearthnet.py | 5 +++++ torchgeo/datamodules/cowc.py | 5 +++++ torchgeo/datamodules/etci2021.py | 5 +++++ torchgeo/datamodules/eurosat.py | 5 +++++ torchgeo/datamodules/landcoverai.py | 5 +++++ torchgeo/datamodules/nasa_marine_debris.py | 5 +++++ torchgeo/datamodules/resisc45.py | 5 +++++ torchgeo/datamodules/so2sat.py | 22 ++++++++++------------ torchgeo/datamodules/ucmerced.py | 5 +++++ torchgeo/trainers/classification.py | 4 ++-- torchgeo/trainers/regression.py | 2 +- torchgeo/trainers/segmentation.py | 2 +- 12 files changed, 54 insertions(+), 16 deletions(-) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index 11c2e4ed9ab..890e94fc2cc 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch.utils.data import DataLoader @@ -176,3 +177,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.BigEarthNet.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 4d6e4a7cdb8..a4c6f98e810 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch import Generator # type: ignore[attr-defined] from torch.utils.data import DataLoader, random_split @@ -121,3 +122,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.COWC.plot`.""" + return self.val_dataset.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 5db89a07379..933433f0ceb 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch import Generator # type: ignore[attr-defined] @@ -149,3 +150,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.ETCI2021.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index 72708e07019..8a4281eccc8 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch.utils.data import DataLoader @@ -146,3 +147,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.EuroSAT.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 5cf3812a45c..b189c98fe95 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Optional import kornia.augmentation as K +import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader @@ -158,3 +159,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.LandCoverAI.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index e6337e9fb6a..99fcb49b41c 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch import Tensor @@ -138,3 +139,7 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, collate_fn=collate_fn, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.NASAMarineDebris.plot`.""" + return self.val_dataset.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 0892f1ddaaf..2de467b7b97 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Optional import kornia.augmentation as K +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch.utils.data import DataLoader @@ -156,3 +157,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.RESISC45.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 85992f49725..e74c4457cd3 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, cast +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from torch.utils.data import DataLoader @@ -185,18 +186,6 @@ def setup(self, stage: Optional[str] = None) -> None: So2Sat, temp_train + self.val_dataset + self.test_dataset ) - # So2Sat dataset doesn't know how to plot any band set other than "all" - # TODO: move band selection to the Dataset level so that plot knows about it - if self.bands == "rgb": - # delattr doesn't work for some reason - # https://stackoverflow.com/a/1684219/5828163 - def noattr() -> None: - raise AttributeError - - self.val_dataset.plot = ( # type: ignore[assignment] - lambda *args, **kwargs: noattr() - ) - def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. @@ -235,3 +224,12 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.NASAMarineDebris.plot`.""" + # So2Sat dataset doesn't know how to plot any band set other than "all" + # TODO: move band selection to the Dataset level so that plot knows about it + if self.bands == "rgb": + raise AttributeError + + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 69cd9773384..77dc718f7b2 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch import torchvision @@ -123,3 +124,7 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.UCMerced.plot`.""" + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 03c5232cd63..a0a5a49bac7 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -184,7 +184,7 @@ def validation_step( # type: ignore[override] try: datamodule = self.trainer.datamodule # type: ignore[attr-defined] sample = unbind_samples(batch)[0] - fig = datamodule.val_dataset.plot(sample) + fig = datamodule.plot(sample) summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step @@ -349,7 +349,7 @@ def validation_step( # type: ignore[override] try: datamodule = self.trainer.datamodule # type: ignore[attr-defined] sample = unbind_samples(batch)[0] - fig = datamodule.val_dataset.plot(sample) + fig = datamodule.plot(sample) summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 2cbff6e2f98..2a71e9851b2 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -113,7 +113,7 @@ def validation_step( # type: ignore[override] try: datamodule = self.trainer.datamodule # type: ignore[attr-defined] sample = unbind_samples(batch)[0] - fig = datamodule.val_dataset.plot(sample) + fig = datamodule.plot(sample) summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 1b71fb14f40..0906c6b5234 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -178,7 +178,7 @@ def validation_step( # type: ignore[override] try: datamodule = self.trainer.datamodule # type: ignore[attr-defined] sample = unbind_samples(batch)[0] - fig = datamodule.val_dataset.plot(sample) + fig = datamodule.plot(sample) summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step