Skip to content

Commit

Permalink
Use datamodule plot instead of dataset plot
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Dec 31, 2021
1 parent d8345f1 commit 6bf2b19
Show file tree
Hide file tree
Showing 12 changed files with 54 additions and 16 deletions.
5 changes: 5 additions & 0 deletions torchgeo/datamodules/bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -176,3 +177,7 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.BigEarthNet.plot`."""
return self.val_dataset.plot(*args, **kwargs)
5 changes: 5 additions & 0 deletions torchgeo/datamodules/cowc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch import Generator # type: ignore[attr-defined]
from torch.utils.data import DataLoader, random_split
Expand Down Expand Up @@ -121,3 +122,7 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.COWC.plot`."""
return self.val_dataset.dataset.plot(*args, **kwargs)
5 changes: 5 additions & 0 deletions torchgeo/datamodules/etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch import Generator # type: ignore[attr-defined]
Expand Down Expand Up @@ -149,3 +150,7 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.ETCI2021.plot`."""
return self.val_dataset.plot(*args, **kwargs)
5 changes: 5 additions & 0 deletions torchgeo/datamodules/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -146,3 +147,7 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.EuroSAT.plot`."""
return self.val_dataset.plot(*args, **kwargs)
5 changes: 5 additions & 0 deletions torchgeo/datamodules/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Dict, Optional

import kornia.augmentation as K
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -158,3 +159,7 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.LandCoverAI.plot`."""
return self.val_dataset.plot(*args, **kwargs)
5 changes: 5 additions & 0 deletions torchgeo/datamodules/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch import Tensor
Expand Down Expand Up @@ -138,3 +139,7 @@ def test_dataloader(self) -> DataLoader[Any]:
shuffle=False,
collate_fn=collate_fn,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.NASAMarineDebris.plot`."""
return self.val_dataset.dataset.plot(*args, **kwargs)
5 changes: 5 additions & 0 deletions torchgeo/datamodules/resisc45.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Dict, Optional

import kornia.augmentation as K
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -156,3 +157,7 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.RESISC45.plot`."""
return self.val_dataset.plot(*args, **kwargs)
22 changes: 10 additions & 12 deletions torchgeo/datamodules/so2sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, Optional, cast

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -185,18 +186,6 @@ def setup(self, stage: Optional[str] = None) -> None:
So2Sat, temp_train + self.val_dataset + self.test_dataset
)

# So2Sat dataset doesn't know how to plot any band set other than "all"
# TODO: move band selection to the Dataset level so that plot knows about it
if self.bands == "rgb":
# delattr doesn't work for some reason
# https://stackoverflow.com/a/1684219/5828163
def noattr() -> None:
raise AttributeError

self.val_dataset.plot = ( # type: ignore[assignment]
lambda *args, **kwargs: noattr()
)

def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Expand Down Expand Up @@ -235,3 +224,12 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.NASAMarineDebris.plot`."""
# So2Sat dataset doesn't know how to plot any band set other than "all"
# TODO: move band selection to the Dataset level so that plot knows about it
if self.bands == "rgb":
raise AttributeError

return self.val_dataset.plot(*args, **kwargs)
5 changes: 5 additions & 0 deletions torchgeo/datamodules/ucmerced.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torchvision
Expand Down Expand Up @@ -123,3 +124,7 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.UCMerced.plot`."""
return self.val_dataset.plot(*args, **kwargs)
4 changes: 2 additions & 2 deletions torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def validation_step( # type: ignore[override]
try:
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
sample = unbind_samples(batch)[0]
fig = datamodule.val_dataset.plot(sample)
fig = datamodule.plot(sample)
summary_writer = self.logger.experiment
summary_writer.add_figure(
f"image/{batch_idx}", fig, global_step=self.global_step
Expand Down Expand Up @@ -349,7 +349,7 @@ def validation_step( # type: ignore[override]
try:
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
sample = unbind_samples(batch)[0]
fig = datamodule.val_dataset.plot(sample)
fig = datamodule.plot(sample)
summary_writer = self.logger.experiment
summary_writer.add_figure(
f"image/{batch_idx}", fig, global_step=self.global_step
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def validation_step( # type: ignore[override]
try:
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
sample = unbind_samples(batch)[0]
fig = datamodule.val_dataset.plot(sample)
fig = datamodule.plot(sample)
summary_writer = self.logger.experiment
summary_writer.add_figure(
f"image/{batch_idx}", fig, global_step=self.global_step
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def validation_step( # type: ignore[override]
try:
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
sample = unbind_samples(batch)[0]
fig = datamodule.val_dataset.plot(sample)
fig = datamodule.plot(sample)
summary_writer = self.logger.experiment
summary_writer.add_figure(
f"image/{batch_idx}", fig, global_step=self.global_step
Expand Down

0 comments on commit 6bf2b19

Please sign in to comment.