Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot functions to datasets #251

Merged
merged 16 commits into from
Nov 19, 2021
11 changes: 11 additions & 0 deletions tests/datasets/test_eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -90,6 +91,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match=err):
EuroSAT(str(tmp_path))

def test_plot(self, dataset: EuroSAT) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["label"].clone()
dataset.plot(x)
plt.close()


class TestEuroSATDataModule:
@pytest.fixture(scope="class")
Expand Down
11 changes: 11 additions & 0 deletions tests/datasets/test_landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -68,6 +69,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
LandCoverAI(str(tmp_path))

def test_plot(self, dataset: LandCoverAI) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["mask"].clone()
dataset.plot(x)
plt.close()


class TestLandCoverAIDataModule:
@pytest.fixture(scope="class")
Expand Down
11 changes: 11 additions & 0 deletions tests/datasets/test_resisc45.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -91,6 +92,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match=err):
RESISC45(str(tmp_path))

def test_plot(self, dataset: RESISC45) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["label"].clone()
dataset.plot(x)
plt.close()


class TestRESISC45DataModule:
@pytest.fixture(scope="class")
Expand Down
13 changes: 5 additions & 8 deletions torchgeo/datasets/etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Generator, Tensor # type: ignore[attr-defined]
from torch.utils.data import DataLoader, random_split
Expand Down Expand Up @@ -265,7 +264,7 @@ def plot(
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> Figure:
) -> plt.Figure:
"""Plot a sample from the dataset.

Args:
Expand All @@ -280,18 +279,16 @@ def plot(
vh = np.rollaxis(sample["image"][3:].numpy(), 0, 3)
water_mask = sample["mask"][0].numpy()

showing_flood_mask = False
showing_predictions = False
showing_flood_mask = sample["mask"].shape[0] > 1
showing_predictions = "prediction" in sample
num_panels = 3
if sample["mask"].shape[0] > 1:
if showing_flood_mask:
flood_mask = sample["mask"][1].numpy()
num_panels += 1
showing_flood_mask = True

if "prediction" in sample:
if showing_predictions:
predictions = sample["prediction"].numpy()
num_panels += 1
showing_predictions = True

fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 3))
axs[0].imshow(vv)
Expand Down
56 changes: 55 additions & 1 deletion torchgeo/datasets/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
"""EuroSAT dataset."""

import os
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, cast

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
Expand Down Expand Up @@ -72,6 +74,18 @@ class EuroSAT(VisionClassificationDataset):
"val": "95de90f2aa998f70a3b2416bfe0687b4",
"test": "7ae5ab94471417b6e315763121e67c5f",
}
classes = [
"Industrial Buildings",
"Residential Buildings",
"Annual Crop",
"Permanent Crop",
"River",
"Sea and Lake",
"Herbaceous Vegetation",
"Highway",
"Pasture",
"Forest",
]

def __init__(
self,
Expand Down Expand Up @@ -174,6 +188,46 @@ def _extract(self) -> None:
filepath = os.path.join(self.root, self.filename)
extract_archive(filepath)

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle

Returns:
a matplotlib Figure with the rendered sample
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved
"""
image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3).copy()
image = np.clip(image / 3000, 0, 1)

label = cast(int, sample["label"].item())
label_class = self.classes[label]

showing_predictions = "prediction" in sample
if showing_predictions:
prediction = cast(int, sample["prediction"].item())
prediction_class = self.classes[prediction]

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(image)
ax.axis("off")
if show_titles:
title = f"Label: {label_class}"
if showing_predictions:
title += f"\nPrediction: {prediction_class}"
ax.set_title(title)

if suptitle is not None:
plt.suptitle(suptitle)
return fig


class EuroSATDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the EuroSAT dataset.
Expand Down
58 changes: 58 additions & 0 deletions torchgeo/datasets/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from functools import lru_cache
from typing import Any, Callable, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib.colors import ListedColormap
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -67,6 +69,16 @@ class LandCoverAI(VisionDataset):
filename = "landcover.ai.v1.zip"
md5 = "3268c89070e8734b4e91d531c0617e03"
sha256 = "15ee4ca9e3fd187957addfa8f0d74ac31bc928a966f76926e11b3c33ea76daa1"
classes = ["Background", "Building", "Woodland", "Water", "Road"]
cmap = ListedColormap(
[
[0.63921569, 1.0, 0.45098039],
[0.61176471, 0.61176471, 0.61176471],
[0.14901961, 0.45098039, 0.0],
[0.0, 0.77254902, 1.0],
[0.0, 0.0, 0.0],
]
)

def __init__(
self,
Expand Down Expand Up @@ -207,6 +219,52 @@ def _download(self) -> None:
assert hashlib.sha256(split).hexdigest() == self.sha256
exec(split)

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle

Returns:
a matplotlib Figure with the rendered sample
"""
image = np.rollaxis(sample["image"].numpy(), 0, 3)
mask = sample["mask"].numpy()

num_panels = 2
showing_predictions = "prediction" in sample
if showing_predictions:
predictions = sample["prediction"].numpy()
num_panels += 1

fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
axs[0].imshow(image)
axs[0].axis("off")
axs[1].imshow(mask, vmin=0, vmax=4, cmap=self.cmap, interpolation="none")
axs[1].axis("off")
if show_titles:
axs[0].set_title("Image")
axs[1].set_title("Mask")

if showing_predictions:
axs[2].imshow(
predictions, vmin=0, vmax=4, cmap=self.cmap, interpolation="none"
)
axs[2].axis("off")
if show_titles:
axs[2].set_title("Predictions")

if suptitle is not None:
plt.suptitle(suptitle)
return fig


class LandCoverAIDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the LandCover.ai dataset.
Expand Down
89 changes: 88 additions & 1 deletion torchgeo/datasets/resisc45.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
"""RESISC45 dataset."""

import os
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, cast

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
Expand Down Expand Up @@ -113,6 +115,53 @@ class RESISC45(VisionClassificationDataset):
"val": "a0770cee4c5ca20b8c32bbd61e114805",
"test": "3dda9e4988b47eb1de9f07993653eb08",
}
classes = [
"airplane",
"airport",
"baseball_diamond",
"basketball_court",
"beach",
"bridge",
"chaparral",
"church",
"circular_farmland",
"cloud",
"commercial_area",
"dense_residential",
"desert",
"forest",
"freeway",
"golf_course",
"ground_track_field",
"harbor",
"industrial_area",
"intersection",
"island",
"lake",
"meadow",
"medium_residential",
"mobile_home_park",
"mountain",
"overpass",
"palace",
"parking_lot",
"railway",
"railway_station",
"rectangular_farmland",
"river",
"roundabout",
"runway",
"sea_ice",
"ship",
"snowberg",
"sparse_residential",
"stadium",
"storage_tank",
"tennis_court",
"terrace",
"thermal_power_station",
"wetland",
]

def __init__(
self,
Expand Down Expand Up @@ -200,6 +249,44 @@ def _extract(self) -> None:
filepath = os.path.join(self.root, self.filename)
extract_archive(filepath)

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle

Returns:
a matplotlib Figure with the rendered sample
"""
image = np.rollaxis(sample["image"].numpy(), 0, 3)
label = cast(int, sample["label"].item())
label_class = self.classes[label]

showing_predictions = "prediction" in sample
if showing_predictions:
prediction = cast(int, sample["prediction"].item())
prediction_class = self.classes[prediction]

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(image)
ax.axis("off")
if show_titles:
title = f"Label: {label_class}"
if showing_predictions:
title += f"\nPrediction: {prediction_class}"
ax.set_title(title)

if suptitle is not None:
plt.suptitle(suptitle)
return fig


class RESISC45DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the RESISC45 dataset.
Expand Down
Loading