From c3537086925a2b74ca4359d21be7add88e0ddb0d Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 07:02:36 +0000 Subject: [PATCH 01/16] Adding plot function for EuroSAT --- tests/datasets/test_eurosat.py | 11 ++++++ torchgeo/datasets/eurosat.py | 64 ++++++++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index 4e204b93c7f..008195bb72a 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -90,6 +91,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match=err): EuroSAT(str(tmp_path)) + def test_plot(self, dataset: EuroSAT) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestEuroSATDataModule: @pytest.fixture(scope="class") diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 2a77e8c0ef8..7685412a515 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -4,16 +4,24 @@ """EuroSAT dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, cast +import matplotlib.pyplot as plt +import numpy as np import pytorch_lightning as pl import torch +from matplotlib.figure import Figure from torch import Tensor from torch.utils.data import DataLoader from torchvision.transforms import Compose, Normalize from .geo import VisionClassificationDataset -from .utils import check_integrity, download_url, extract_archive, rasterio_loader +from .utils import ( + check_integrity, + download_url, + extract_archive, + rasterio_loader, +) class 
EuroSAT(VisionClassificationDataset): @@ -72,6 +80,18 @@ class EuroSAT(VisionClassificationDataset): "val": "95de90f2aa998f70a3b2416bfe0687b4", "test": "7ae5ab94471417b6e315763121e67c5f", } + classes = [ + "Industrial Buildings", + "Residential Buildings", + "Annual Crop", + "Permanent Crop", + "River", + "Sea and Lake", + "Herbaceous Vegetation", + "Highway", + "Pasture", + "Forest", + ] def __init__( self, @@ -174,6 +194,46 @@ def _extract(self) -> None: filepath = os.path.join(self.root, self.filename) extract_archive(filepath) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3).copy() + image = np.clip(image / 3000, 0, 1) + + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class EuroSATDataModule(pl.LightningDataModule): """LightningDataModule implementation for the EuroSAT dataset. 
From 18e662d21c03b37350d99ed80eab65d1e4f35fca Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 07:40:02 +0000 Subject: [PATCH 02/16] Added LandCover.ai plot --- tests/datasets/test_landcoverai.py | 11 ++++++ torchgeo/datasets/landcoverai.py | 58 ++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index f197077f225..8a6942e5800 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -68,6 +69,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): LandCoverAI(str(tmp_path)) + def test_plot(self, dataset: LandCoverAI) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["mask"].clone() + dataset.plot(x) + plt.close() + class TestLandCoverAIDataModule: @pytest.fixture(scope="class") diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 771e87f5c9a..f98e43f974f 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -8,9 +8,11 @@ from functools import lru_cache from typing import Any, Callable, Dict, Optional +import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch +from matplotlib.colors import ListedColormap from PIL import Image from torch import Tensor from torch.utils.data import DataLoader @@ -67,6 +69,16 @@ class LandCoverAI(VisionDataset): filename = "landcover.ai.v1.zip" md5 = "3268c89070e8734b4e91d531c0617e03" sha256 = "15ee4ca9e3fd187957addfa8f0d74ac31bc928a966f76926e11b3c33ea76daa1" + classes = ["Background", "Building", "Woodland", "Water", "Road"] + cmap = ListedColormap( + 
[ + [0.63921569, 1.0, 0.45098039], + [0.61176471, 0.61176471, 0.61176471], + [0.14901961, 0.45098039, 0.0], + [0.0, 0.77254902, 1.0], + [0.0, 0.0, 0.0], + ] + ) def __init__( self, @@ -207,6 +219,52 @@ def _download(self) -> None: assert hashlib.sha256(split).hexdigest() == self.sha256 exec(split) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + mask = sample["mask"].numpy() + + num_panels = 2 + showing_predictions = "prediction" in sample + if showing_predictions: + predictions = sample["prediction"].numpy() + num_panels += 1 + + fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) + axs[0].imshow(image) + axs[0].axis("off") + axs[1].imshow(mask, vmin=0, vmax=4, cmap=self.cmap, interpolation="none") + axs[1].axis("off") + if show_titles: + axs[0].set_title("Image") + axs[1].set_title("Mask") + + if showing_predictions: + axs[2].imshow( + predictions, vmin=0, vmax=4, cmap=self.cmap, interpolation="none" + ) + axs[2].axis("off") + if show_titles: + axs[2].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class LandCoverAIDataModule(pl.LightningDataModule): """LightningDataModule implementation for the LandCover.ai dataset. 
From 3b4709a944d6f917d5ac1babd1eebc7f18ba721c Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 07:40:27 +0000 Subject: [PATCH 03/16] Cleaning up --- torchgeo/datasets/etci2021.py | 13 +++++-------- torchgeo/datasets/eurosat.py | 10 ++-------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 56d25622ca8..bb10da22bff 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -11,7 +11,6 @@ import numpy as np import pytorch_lightning as pl import torch -from matplotlib.figure import Figure from PIL import Image from torch import Generator, Tensor # type: ignore[attr-defined] from torch.utils.data import DataLoader, random_split @@ -265,7 +264,7 @@ def plot( sample: Dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, - ) -> Figure: + ) -> plt.Figure: """Plot a sample from the dataset. Args: @@ -280,18 +279,16 @@ def plot( vh = np.rollaxis(sample["image"][3:].numpy(), 0, 3) water_mask = sample["mask"][0].numpy() - showing_flood_mask = False - showing_predictions = False + showing_flood_mask = sample["mask"].shape[0] > 1 + showing_predictions = "prediction" in sample num_panels = 3 - if sample["mask"].shape[0] > 1: + if showing_flood_mask: flood_mask = sample["mask"][1].numpy() num_panels += 1 - showing_flood_mask = True - if "prediction" in sample: + if showing_predictions: predictions = sample["prediction"].numpy() num_panels += 1 - showing_predictions = True fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 3)) axs[0].imshow(vv) diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 7685412a515..b9c294fc652 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -10,18 +10,12 @@ import numpy as np import pytorch_lightning as pl import torch -from matplotlib.figure import Figure from torch import Tensor from torch.utils.data import DataLoader from 
torchvision.transforms import Compose, Normalize from .geo import VisionClassificationDataset -from .utils import ( - check_integrity, - download_url, - extract_archive, - rasterio_loader, -) +from .utils import check_integrity, download_url, extract_archive, rasterio_loader class EuroSAT(VisionClassificationDataset): @@ -199,7 +193,7 @@ def plot( sample: Dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, - ) -> Figure: + ) -> plt.Figure: """Plot a sample from the dataset. Args: From e6bfe32a76e9114e7b767fd00395319da26b24dc Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 07:56:59 +0000 Subject: [PATCH 04/16] Added method for percentile normalization --- torchgeo/datasets/utils.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index a888805c1a1..508c2bbb828 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -12,7 +12,7 @@ import tarfile import zipfile from datetime import datetime, timedelta -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast import numpy as np import rasterio @@ -482,3 +482,33 @@ def rgb_to_mask( if isinstance(cmask, np.ndarray): mask[cmask.all(axis=-1)] = i return mask + + +def percentile_normalization( + img: np.ndarray, lower: float = 2, upper: float = 98 # type: ignore[type-arg] +) -> np.ndarray: # type: ignore[type-arg] + """Applies percentile normalization to an input image. + + Specifically, this will rescale the values in the input such that values <= the + lower percentile value will be 0 and values >= the upper percentile value will be 1. + Using the 2nd and 98th percentile usually results in good visualizations. 
+ + Args: + img: image to normalize + lower: lower percentile in range [0,100] + upper: upper percentile in range [0,100] + + Returns + normalized version of ``img`` as + """ + assert lower < upper + lower_percentile = np.percentile( # type: ignore[no-untyped-call] + img, lower, axis=(0, 1) + ) + upper_percentile = np.percentile( # type: ignore[no-untyped-call] + img, upper, axis=(0, 1) + ) + img_normalized = np.clip( + (img - lower_percentile) / (upper_percentile - lower_percentile), 0, 1 + ) + return cast(np.ndarray, img_normalized) # type: ignore[type-arg] From 47f2234952a114fa7264935adc27c04c58db2d39 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 08:06:03 +0000 Subject: [PATCH 05/16] Adding RESISC45 plot --- tests/datasets/test_resisc45.py | 11 ++++ torchgeo/datasets/resisc45.py | 89 ++++++++++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py index eadfe95b901..1b68dfae89c 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -91,6 +92,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match=err): RESISC45(str(tmp_path)) + def test_plot(self, dataset: RESISC45) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestRESISC45DataModule: @pytest.fixture(scope="class") diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index ebdf580c0a0..4eadde234ce 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -4,8 +4,10 @@ """RESISC45 dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing 
import Any, Callable, Dict, Optional, cast +import matplotlib.pyplot as plt +import numpy as np import pytorch_lightning as pl import torch from torch import Tensor @@ -113,6 +115,53 @@ class RESISC45(VisionClassificationDataset): "val": "a0770cee4c5ca20b8c32bbd61e114805", "test": "3dda9e4988b47eb1de9f07993653eb08", } + classes = [ + "airplane", + "airport", + "baseball_diamond", + "basketball_court", + "beach", + "bridge", + "chaparral", + "church", + "circular_farmland", + "cloud", + "commercial_area", + "dense_residential", + "desert", + "forest", + "freeway", + "golf_course", + "ground_track_field", + "harbor", + "industrial_area", + "intersection", + "island", + "lake", + "meadow", + "medium_residential", + "mobile_home_park", + "mountain", + "overpass", + "palace", + "parking_lot", + "railway", + "railway_station", + "rectangular_farmland", + "river", + "roundabout", + "runway", + "sea_ice", + "ship", + "snowberg", + "sparse_residential", + "stadium", + "storage_tank", + "tennis_court", + "terrace", + "thermal_power_station", + "wetland", + ] def __init__( self, @@ -200,6 +249,44 @@ def _extract(self) -> None: filepath = os.path.join(self.root, self.filename) extract_archive(filepath) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. 
+ + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class RESISC45DataModule(pl.LightningDataModule): """LightningDataModule implementation for the RESISC45 dataset. From 3107ccde8b8aef58ae8c519567d75aaeb65f8b48 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 17:36:21 +0000 Subject: [PATCH 06/16] Adding versionadded tags and fixing docs --- torchgeo/datasets/eurosat.py | 4 +++- torchgeo/datasets/landcoverai.py | 2 ++ torchgeo/datasets/resisc45.py | 4 +++- torchgeo/datasets/utils.py | 4 +++- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index b9c294fc652..9ba06c1c683 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -197,12 +197,14 @@ def plot( """Plot a sample from the dataset. Args: - sample: a sample returned by :meth:`__getitem__` + sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional string to use as a suptitle Returns: a matplotlib Figure with the rendered sample + + .. 
versionadded:: 0.2 """ image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3).copy() image = np.clip(image / 3000, 0, 1) diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index f98e43f974f..e579d668d63 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -234,6 +234,8 @@ def plot( Returns: a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 """ image = np.rollaxis(sample["image"].numpy(), 0, 3) mask = sample["mask"].numpy() diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index 4eadde234ce..4b5c9560a0b 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -258,12 +258,14 @@ def plot( """Plot a sample from the dataset. Args: - sample: a sample returned by :meth:`__getitem__` + sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional string to use as a suptitle Returns: a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 """ image = np.rollaxis(sample["image"].numpy(), 0, 3) label = cast(int, sample["label"].item()) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 508c2bbb828..436d3b636a8 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -499,7 +499,9 @@ def percentile_normalization( upper: upper percentile in range [0,100] Returns - normalized version of ``img`` as + normalized version of ``img`` + + .. 
versionadded:: 0.2 """ assert lower < upper lower_percentile = np.percentile( # type: ignore[no-untyped-call] From 53fa17ef724175d966e4afbdb0a5287e7df81e5d Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 18:08:57 +0000 Subject: [PATCH 07/16] So2Sat should return tensor labels --- tests/datasets/test_so2sat.py | 2 +- torchgeo/datasets/so2sat.py | 42 ++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index 5d5e86fcc49..878da49e584 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -35,7 +35,7 @@ def test_getitem(self, dataset: So2Sat) -> None: x = dataset[0] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], int) + assert isinstance(x["label"], torch.Tensor) def test_len(self, dataset: So2Sat) -> None: assert len(dataset) == 10 diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index d0e03cef29c..b3559156eb5 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -42,6 +42,26 @@ class So2Sat(VisionDataset): * Validation: western half of 10 other cities covering 10 cultural zones * Testing: eastern half of the 10 other cities + Dataset classes: + + 0. Compact high rise + 1. Compact mid rise + 2. Compact low rise + 3. Open high rise + 4. Open mid rise + 5. Open low rise + 6. Lightweight low rise + 7. Large low rise + 8. Sparsely built + 9. Heavy industry + 10. Dense trees + 11. Scattered trees + 12. Bush, scrub + 13. Low plants + 14. Bare rock or paved + 15. Bare soil or sand + 16. 
Water + If you use this dataset in your research, please cite the following paper: * https://doi.org/10.1109/MGRS.2020.2964708 @@ -71,6 +91,25 @@ class So2Sat(VisionDataset): "validation": "71cfa6795de3e22207229d06d6f8775d", "test": "e81426102b488623a723beab52b31a8a", } + classes = [ + "Compact high rise", + "Compact mid rise", + "Compact low rise", + "Open high rise", + "Open mid rise", + "Open low rise", + "Lightweight low rise", + "Large low rise", + "Sparsely built", + "Heavy industry", + "Dense trees", + "Scattered trees", + "Bush, scrub", + "Low plants", + "Bare rock or paved", + "Bare soil or sand", + "Water", + ] def __init__( self, @@ -123,7 +162,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: with h5py.File(self.fn, "r") as f: s1 = f["sen1"][index].astype(np.float64) # convert from Date: Fri, 19 Nov 2021 18:35:55 +0000 Subject: [PATCH 08/16] Added So2Sat plot --- tests/datasets/test_so2sat.py | 11 +++++++++ torchgeo/datasets/so2sat.py | 44 ++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index 878da49e584..79aa62d7f2b 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -54,6 +55,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): So2Sat(str(tmp_path)) + def test_plot(self, dataset: So2Sat) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestSo2SatDataModule: @pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"])) diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 
b3559156eb5..32eb8ad59d3 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -6,6 +6,7 @@ import os from typing import Any, Callable, Dict, Optional, cast +import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch @@ -14,7 +15,7 @@ from torchvision.transforms import Compose from .geo import VisionDataset -from .utils import check_integrity +from .utils import check_integrity, percentile_normalization # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 @@ -204,6 +205,47 @@ def _check_integrity(self) -> bool: return False return True + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = np.rollaxis(sample["image"][[10, 9, 8]].numpy(), 0, 3) + image = percentile_normalization(image, 0, 100) + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class So2SatDataModule(pl.LightningDataModule): """LightningDataModule implementation for the So2Sat dataset. 
From 58898b8e96211dbc5ab6d4a72ae036b6244a24c0 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 18:37:39 +0000 Subject: [PATCH 09/16] Added UCMerced plot --- tests/datasets/test_ucmerced.py | 11 +++++++++ torchgeo/datasets/ucmerced.py | 44 ++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py index 01963de7d3e..ad6efb6628b 100644 --- a/tests/datasets/test_ucmerced.py +++ b/tests/datasets/test_ucmerced.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -92,6 +93,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match=err): UCMerced(str(tmp_path)) + def test_plot(self, dataset: UCMerced) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestUCMercedDataModule: @pytest.fixture(scope="class") diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index ac672c0da70..431b526b756 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -3,8 +3,10 @@ """UC Merced dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, cast +import matplotlib.pyplot as plt +import numpy as np import pytorch_lightning as pl import torch import torchvision @@ -210,6 +212,46 @@ def _extract(self) -> None: filepath = os.path.join(self.root, self.filename) extract_archive(filepath) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. 
+ + Args: + sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class UCMercedDataModule(pl.LightningDataModule): """LightningDataModule implementation for the UC Merced dataset. From 5ea1c21264115634134f61edfcf7bcdc1a77a13d Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 18:37:57 +0000 Subject: [PATCH 10/16] Changed percentile normalization to calculate values overall and not by band --- torchgeo/datasets/utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 436d3b636a8..4f2048b16c2 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -504,12 +504,8 @@ def percentile_normalization( .. 
versionadded:: 0.2 """ assert lower < upper - lower_percentile = np.percentile( # type: ignore[no-untyped-call] - img, lower, axis=(0, 1) - ) - upper_percentile = np.percentile( # type: ignore[no-untyped-call] - img, upper, axis=(0, 1) - ) + lower_percentile = np.percentile(img, lower) # type: ignore[no-untyped-call] + upper_percentile = np.percentile(img, upper) # type: ignore[no-untyped-call] img_normalized = np.clip( (img - lower_percentile) / (upper_percentile - lower_percentile), 0, 1 ) From 48a774fee10b1587d1dd2ac1d11edc64357dba13 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 19:08:00 +0000 Subject: [PATCH 11/16] Added SeCo plot --- tests/datasets/test_seco.py | 17 ++++++++++++ torchgeo/datasets/seco.py | 52 ++++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_seco.py b/tests/datasets/test_seco.py index a95b43ad7f4..585936dfaf4 100644 --- a/tests/datasets/test_seco.py +++ b/tests/datasets/test_seco.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -90,3 +91,19 @@ def test_invalid_band(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found"): SeasonalContrastS2(str(tmp_path)) + + def test_plot(self, dataset: SeasonalContrastS2) -> None: + if not all(band in dataset.bands for band in dataset.RGB_BANDS): + with pytest.raises(ValueError, match="Dataset doesn't contain"): + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + else: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + + with pytest.raises(ValueError, match="doesn't support plotting"): + x["prediction"] = torch.tensor(1) # type: ignore[attr-defined] + dataset.plot(x) diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index 
2afb5a838ae..f196a7981d2 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -7,6 +7,7 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, cast +import matplotlib.pyplot as plt import numpy as np import rasterio import torch @@ -14,7 +15,7 @@ from torch import Tensor from .geo import VisionDataset -from .utils import download_url, extract_archive +from .utils import download_url, extract_archive, percentile_normalization class SeasonalContrastS2(VisionDataset): @@ -229,3 +230,52 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" extract_archive(os.path.join(self.root, self.filename)) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + Raises: + ValueError: if the RGB bands are not included in ``self.bands`` or the sample + contains a "prediction" key + + .. 
versionadded:: 0.2 + """ + if "prediction" in sample: + raise ValueError("This dataset doesn't support plotting predictions") + + rgb_indices = [] + for band in self.RGB_BANDS: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + images = [] + for i in range(5): + image = np.rollaxis(sample["image"][i, rgb_indices].numpy(), 0, 3) + image = percentile_normalization(image, 0, 100) + images.append(image) + + fig, axs = plt.subplots(ncols=5, figsize=(20, 4)) + for i in range(5): + axs[i].imshow(images[i]) + axs[i].axis("off") + if show_titles: + axs[i].set_title(f"t={i+1}") + + if suptitle is not None: + plt.suptitle(suptitle) + return fig From 0b2082c0a183bc10ad5d3dda8d8fb4a06370e966 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 19:08:18 +0000 Subject: [PATCH 12/16] Fixed So2Sat doc --- torchgeo/datasets/so2sat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 32eb8ad59d3..0244bab2b12 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -214,7 +214,7 @@ def plot( """Plot a sample from the dataset. 
Args: - sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` + sample: a sample returned by :meth:`__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional string to use as a suptitle From 9837b47c27f03721f9803fe33f0f13e16d1e9cec Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 21:59:14 +0000 Subject: [PATCH 13/16] Adding version to xview2 --- torchgeo/datasets/xview.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 9c84598f083..b8945e0479f 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -11,7 +11,6 @@ import numpy as np import pytorch_lightning as pl import torch -from matplotlib.figure import Figure from PIL import Image from torch import Tensor from torch.utils.data import DataLoader @@ -230,7 +229,7 @@ def plot( show_titles: bool = True, suptitle: Optional[str] = None, alpha: float = 0.5, - ) -> Figure: + ) -> plt.Figure: """Plot a sample from the dataset. Args: @@ -241,6 +240,8 @@ def plot( Returns: a matplotlib Figure with the rendered sample + + .. 
versionadded: 0.2 """ ncols = 2 image1 = draw_semantic_segmentation_masks( From 7cc5f7be6e8a64f6497797b2a2fddd1abeb4618e Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 21:59:29 +0000 Subject: [PATCH 14/16] Testing percentile_normalization --- tests/datasets/test_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 51ea5937857..e719fdb9607 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any, Generator, Tuple +import numpy as np import pytest import torch from _pytest.monkeypatch import MonkeyPatch @@ -29,6 +30,7 @@ download_radiant_mlhub_dataset, extract_archive, working_dir, + percentile_normalization, ) @@ -361,3 +363,14 @@ def test_dataset_split() -> None: assert len(train_ds) == num_samples // 3 assert len(val_ds) == num_samples // 3 assert len(test_ds) == num_samples // 3 + + +def test_percentile_normalization() -> None: + img = np.array([ + [1, 2], + [98, 100], + ]) + + img = percentile_normalization(img, 2, 98) + assert img.min() == 0 + assert img.max() == 1 From 6cf2516ead631fe373b3cdb5c524165ba05e2b60 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 22:01:39 +0000 Subject: [PATCH 15/16] Style in test_utils --- tests/datasets/test_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index e719fdb9607..362a9ffbeaa 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -29,8 +29,8 @@ download_radiant_mlhub_collection, download_radiant_mlhub_dataset, extract_archive, - working_dir, percentile_normalization, + working_dir, ) @@ -366,10 +366,7 @@ def test_dataset_split() -> None: def test_percentile_normalization() -> None: - img = np.array([ - [1, 2], - [98, 100], - ]) + img = np.array([[1, 2], [98, 100]]) img = 
percentile_normalization(img, 2, 98) assert img.min() == 0 From d34f9f417b4e4e67889de5cf37e821d61a206663 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 19 Nov 2021 22:09:24 +0000 Subject: [PATCH 16/16] version de-added --- torchgeo/datasets/xview.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index b8945e0479f..c4e7774e04d 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -240,8 +240,6 @@ def plot( Returns: a matplotlib Figure with the rendered sample - - .. versionadded: 0.2 """ ncols = 2 image1 = draw_semantic_segmentation_masks(