diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index 4e204b93c7f..008195bb72a 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -90,6 +91,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match=err): EuroSAT(str(tmp_path)) + def test_plot(self, dataset: EuroSAT) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestEuroSATDataModule: @pytest.fixture(scope="class") diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index f197077f225..8a6942e5800 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -68,6 +69,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): LandCoverAI(str(tmp_path)) + def test_plot(self, dataset: LandCoverAI) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["mask"].clone() + dataset.plot(x) + plt.close() + class TestLandCoverAIDataModule: @pytest.fixture(scope="class") diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py index eadfe95b901..1b68dfae89c 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -91,6 +92,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match=err): RESISC45(str(tmp_path)) + def test_plot(self, dataset: RESISC45) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestRESISC45DataModule: @pytest.fixture(scope="class") diff --git a/tests/datasets/test_seco.py b/tests/datasets/test_seco.py index a95b43ad7f4..585936dfaf4 100644 --- a/tests/datasets/test_seco.py +++ b/tests/datasets/test_seco.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -90,3 +91,19 @@ def test_invalid_band(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found"): SeasonalContrastS2(str(tmp_path)) + + def test_plot(self, dataset: SeasonalContrastS2) -> None: + if not all(band in dataset.bands for band in dataset.RGB_BANDS): + with pytest.raises(ValueError, match="Dataset doesn't contain"): + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + else: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + + with pytest.raises(ValueError, match="doesn't support plotting"): + x["prediction"] = torch.tensor(1) # type: ignore[attr-defined] + dataset.plot(x) diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index 5d5e86fcc49..79aa62d7f2b 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -35,7 +36,7 @@ def test_getitem(self, dataset: So2Sat) -> None: x = dataset[0] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], int) + assert isinstance(x["label"], torch.Tensor) def test_len(self, dataset: So2Sat) -> None: assert len(dataset) == 10 @@ -54,6 +55,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): So2Sat(str(tmp_path)) + def test_plot(self, dataset: So2Sat) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestSo2SatDataModule: @pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"])) diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py index 01963de7d3e..ad6efb6628b 100644 --- a/tests/datasets/test_ucmerced.py +++ b/tests/datasets/test_ucmerced.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -92,6 +93,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match=err): UCMerced(str(tmp_path)) + def test_plot(self, dataset: UCMerced) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestUCMercedDataModule: @pytest.fixture(scope="class") diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 51ea5937857..362a9ffbeaa 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any, Generator, Tuple +import numpy as np import pytest import torch from _pytest.monkeypatch import MonkeyPatch @@ -28,6 +29,7 @@ download_radiant_mlhub_collection, download_radiant_mlhub_dataset, extract_archive, + percentile_normalization, working_dir, ) @@ -361,3 +363,11 @@ def test_dataset_split() -> None: assert len(train_ds) == num_samples // 3 assert len(val_ds) == num_samples // 3 assert len(test_ds) == num_samples // 3 + + +def test_percentile_normalization() -> None: + img = np.array([[1, 2], [98, 100]]) + + img = percentile_normalization(img, 2, 98) + assert img.min() == 0 + assert img.max() == 1 diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 56d25622ca8..bb10da22bff 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -11,7 +11,6 @@ import numpy as np import pytorch_lightning as pl import torch -from matplotlib.figure import Figure from PIL import Image from torch import Generator, Tensor # type: ignore[attr-defined] from torch.utils.data import DataLoader, random_split @@ -265,7 +264,7 @@ def plot( sample: Dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, - ) -> Figure: + ) -> plt.Figure: """Plot a sample from the dataset. Args: @@ -280,18 +279,16 @@ def plot( vh = np.rollaxis(sample["image"][3:].numpy(), 0, 3) water_mask = sample["mask"][0].numpy() - showing_flood_mask = False - showing_predictions = False + showing_flood_mask = sample["mask"].shape[0] > 1 + showing_predictions = "prediction" in sample num_panels = 3 - if sample["mask"].shape[0] > 1: + if showing_flood_mask: flood_mask = sample["mask"][1].numpy() num_panels += 1 - showing_flood_mask = True - if "prediction" in sample: + if showing_predictions: predictions = sample["prediction"].numpy() num_panels += 1 - showing_predictions = True fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 3)) axs[0].imshow(vv) diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 2a77e8c0ef8..9ba06c1c683 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -4,8 +4,10 @@ """EuroSAT dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, cast +import matplotlib.pyplot as plt +import numpy as np import pytorch_lightning as pl import torch from torch import Tensor @@ -72,6 +74,18 @@ class EuroSAT(VisionClassificationDataset): "val": "95de90f2aa998f70a3b2416bfe0687b4", "test": "7ae5ab94471417b6e315763121e67c5f", } + classes = [ + "Industrial Buildings", + "Residential Buildings", + "Annual Crop", + "Permanent Crop", + "River", + "Sea and Lake", + "Herbaceous Vegetation", + "Highway", + "Pasture", + "Forest", + ] def __init__( self, @@ -174,6 +188,48 @@ def _extract(self) -> None: filepath = os.path.join(self.root, self.filename) extract_archive(filepath) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3).copy() + image = np.clip(image / 3000, 0, 1) + + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class EuroSATDataModule(pl.LightningDataModule): """LightningDataModule implementation for the EuroSAT dataset. diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 771e87f5c9a..e579d668d63 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -8,9 +8,11 @@ from functools import lru_cache from typing import Any, Callable, Dict, Optional +import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch +from matplotlib.colors import ListedColormap from PIL import Image from torch import Tensor from torch.utils.data import DataLoader @@ -67,6 +69,16 @@ class LandCoverAI(VisionDataset): filename = "landcover.ai.v1.zip" md5 = "3268c89070e8734b4e91d531c0617e03" sha256 = "15ee4ca9e3fd187957addfa8f0d74ac31bc928a966f76926e11b3c33ea76daa1" + classes = ["Background", "Building", "Woodland", "Water", "Road"] + cmap = ListedColormap( + [ + [0.63921569, 1.0, 0.45098039], + [0.61176471, 0.61176471, 0.61176471], + [0.14901961, 0.45098039, 0.0], + [0.0, 0.77254902, 1.0], + [0.0, 0.0, 0.0], + ] + ) def __init__( self, @@ -207,6 +219,54 @@ def _download(self) -> None: assert hashlib.sha256(split).hexdigest() == self.sha256 exec(split) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + mask = sample["mask"].numpy() + + num_panels = 2 + showing_predictions = "prediction" in sample + if showing_predictions: + predictions = sample["prediction"].numpy() + num_panels += 1 + + fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) + axs[0].imshow(image) + axs[0].axis("off") + axs[1].imshow(mask, vmin=0, vmax=4, cmap=self.cmap, interpolation="none") + axs[1].axis("off") + if show_titles: + axs[0].set_title("Image") + axs[1].set_title("Mask") + + if showing_predictions: + axs[2].imshow( + predictions, vmin=0, vmax=4, cmap=self.cmap, interpolation="none" + ) + axs[2].axis("off") + if show_titles: + axs[2].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class LandCoverAIDataModule(pl.LightningDataModule): """LightningDataModule implementation for the LandCover.ai dataset. diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index ebdf580c0a0..4b5c9560a0b 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -4,8 +4,10 @@ """RESISC45 dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, cast +import matplotlib.pyplot as plt +import numpy as np import pytorch_lightning as pl import torch from torch import Tensor @@ -113,6 +115,53 @@ class RESISC45(VisionClassificationDataset): "val": "a0770cee4c5ca20b8c32bbd61e114805", "test": "3dda9e4988b47eb1de9f07993653eb08", } + classes = [ + "airplane", + "airport", + "baseball_diamond", + "basketball_court", + "beach", + "bridge", + "chaparral", + "church", + "circular_farmland", + "cloud", + "commercial_area", + "dense_residential", + "desert", + "forest", + "freeway", + "golf_course", + "ground_track_field", + "harbor", + "industrial_area", + "intersection", + "island", + "lake", + "meadow", + "medium_residential", + "mobile_home_park", + "mountain", + "overpass", + "palace", + "parking_lot", + "railway", + "railway_station", + "rectangular_farmland", + "river", + "roundabout", + "runway", + "sea_ice", + "ship", + "snowberg", + "sparse_residential", + "stadium", + "storage_tank", + "tennis_court", + "terrace", + "thermal_power_station", + "wetland", + ] def __init__( self, @@ -200,6 +249,46 @@ def _extract(self) -> None: filepath = os.path.join(self.root, self.filename) extract_archive(filepath) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class RESISC45DataModule(pl.LightningDataModule): """LightningDataModule implementation for the RESISC45 dataset. diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index 2afb5a838ae..f196a7981d2 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -7,6 +7,7 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, cast +import matplotlib.pyplot as plt import numpy as np import rasterio import torch @@ -14,7 +15,7 @@ from torch import Tensor from .geo import VisionDataset -from .utils import download_url, extract_archive +from .utils import download_url, extract_archive, percentile_normalization class SeasonalContrastS2(VisionDataset): @@ -229,3 +230,52 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" extract_archive(os.path.join(self.root, self.filename)) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + Raises: + ValueError: if the RGB bands are included in ``self.bands`` or the sample + contains a "prediction" key + + .. versionadded:: 0.2 + """ + if "prediction" in sample: + raise ValueError("This dataset doesn't support plotting predictions") + + rgb_indices = [] + for band in self.RGB_BANDS: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + images = [] + for i in range(5): + image = np.rollaxis(sample["image"][i, rgb_indices].numpy(), 0, 3) + image = percentile_normalization(image, 0, 100) + images.append(image) + + fig, axs = plt.subplots(ncols=5, figsize=(20, 4)) + for i in range(5): + axs[i].imshow(images[i]) + axs[i].axis("off") + if show_titles: + axs[i].set_title(f"t={i+1}") + + if suptitle is not None: + plt.suptitle(suptitle) + return fig diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index d0e03cef29c..0244bab2b12 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -6,6 +6,7 @@ import os from typing import Any, Callable, Dict, Optional, cast +import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch @@ -14,7 +15,7 @@ from torchvision.transforms import Compose from .geo import VisionDataset -from .utils import check_integrity +from .utils import check_integrity, percentile_normalization # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 @@ -42,6 +43,26 @@ class So2Sat(VisionDataset): * Validation: western half of 10 other cities covering 10 cultural zones * Testing: eastern half of the 10 other cities + Dataset classes: + + 0. Compact high rise + 1. Compact middle rise + 2. Compact low rise + 3. Open high rise + 4. Open mid rise + 5. Open low rise + 6. Lightweight low rise + 7. Large low rise + 8. Sparsely built + 9. Heavy industry + 10. Dense trees + 11. Scattered trees + 12. Bush, scrub + 13. Low plants + 14. Bare rock or paved + 15. Bare soil or sand + 16. Water + If you use this dataset in your research, please cite the following paper: * https://doi.org/10.1109/MGRS.2020.2964708 @@ -71,6 +92,25 @@ class So2Sat(VisionDataset): "validation": "71cfa6795de3e22207229d06d6f8775d", "test": "e81426102b488623a723beab52b31a8a", } + classes = [ + "Compact high rise", + "Compact mid rise", + "Compact low rise", + "Open high rise", + "Open mid rise", + "Open low rise", + "Lightweight low rise", + "Large low rise", + "Sparsely built", + "Heavy industry", + "Dense trees", + "Scattered trees", + "Bush, scrub", + "Low plants", + "Bare rock or paved", + "Bare soil or sand", + "Water", + ] def __init__( self, @@ -123,7 +163,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: with h5py.File(self.fn, "r") as f: s1 = f["sen1"][index].astype(np.float64) # convert from bool: return False return True + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = np.rollaxis(sample["image"][[10, 9, 8]].numpy(), 0, 3) + image = percentile_normalization(image, 0, 100) + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class So2SatDataModule(pl.LightningDataModule): """LightningDataModule implementation for the So2Sat dataset. diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index ac672c0da70..431b526b756 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -3,8 +3,10 @@ """UC Merced dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, cast +import matplotlib.pyplot as plt +import numpy as np import pytorch_lightning as pl import torch import torchvision @@ -210,6 +212,46 @@ def _extract(self) -> None: filepath = os.path.join(self.root, self.filename) extract_archive(filepath) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + class UCMercedDataModule(pl.LightningDataModule): """LightningDataModule implementation for the UC Merced dataset. diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index a888805c1a1..4f2048b16c2 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -12,7 +12,7 @@ import tarfile import zipfile from datetime import datetime, timedelta -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast import numpy as np import rasterio @@ -482,3 +482,31 @@ def rgb_to_mask( if isinstance(cmask, np.ndarray): mask[cmask.all(axis=-1)] = i return mask + + +def percentile_normalization( + img: np.ndarray, lower: float = 2, upper: float = 98 # type: ignore[type-arg] +) -> np.ndarray: # type: ignore[type-arg] + """Applies percentile normalization to an input image. + + Specifically, this will rescale the values in the input such that values <= the + lower percentile value will be 0 and values >= the upper percentile value will be 1. + Using the 2nd and 98th percentile usually results in good visualizations. + + Args: + img: image to normalize + lower: lower percentile in range [0,100] + upper: upper percentile in range [0,100] + + Returns + normalized version of ``img`` + + .. versionadded:: 0.2 + """ + assert lower < upper + lower_percentile = np.percentile(img, lower) # type: ignore[no-untyped-call] + upper_percentile = np.percentile(img, upper) # type: ignore[no-untyped-call] + img_normalized = np.clip( + (img - lower_percentile) / (upper_percentile - lower_percentile), 0, 1 + ) + return cast(np.ndarray, img_normalized) # type: ignore[type-arg] diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 9c84598f083..c4e7774e04d 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -11,7 +11,6 @@ import numpy as np import pytorch_lightning as pl import torch -from matplotlib.figure import Figure from PIL import Image from torch import Tensor from torch.utils.data import DataLoader @@ -230,7 +229,7 @@ def plot( show_titles: bool = True, suptitle: Optional[str] = None, alpha: float = 0.5, - ) -> Figure: + ) -> plt.Figure: """Plot a sample from the dataset. Args: