Skip to content

Commit

Permalink
Landsat: add plot method (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Jul 10, 2022
1 parent ee657ba commit 7e7443a
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/api/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Dataset,Type,Source,Size (px),Resolution (m)
`GBIF`_,Points,Citizen Scientists,-,-
`GlobBiomass`_,Masks,Landsat,"45,000x45,000",100
`iNaturalist`_,Points,Citizen Scientists,-,-
`Landsat`_,Imagery,Landsat,-,30
`Landsat`_,Imagery,Landsat,"8,900x8,900",30
`NAIP`_,Imagery,Aerial,"6,100x7,600",1
`Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,-
`Sentinel`_,Imagery,Sentinel,"10,000x10,000",10
15 changes: 15 additions & 0 deletions tests/datasets/test_landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -38,6 +39,20 @@ def test_or(self, dataset: Landsat8) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_plot(self, dataset: Landsat8) -> None:
x = dataset[dataset.bounds]
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_wrong_bands(self, dataset: Landsat8) -> None:
bands = ("SR_B1",)
ds = Landsat8(root=dataset.root, bands=bands)
x = dataset[dataset.bounds]
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"
):
ds.plot(x)

def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No Landsat8 data was found in "):
Landsat8(str(tmp_path))
Expand Down
4 changes: 4 additions & 0 deletions torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,10 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
.. versionchanged:: 0.3
Method now takes a sample dict, not a Tensor. Additionally, possible to
show subplot titles and/or use a custom suptitle.
"""
mask = sample["mask"].squeeze().numpy()
ncols = 1
Expand Down
7 changes: 4 additions & 3 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import torch
from matplotlib.colors import ListedColormap
from rasterio.crs import CRS
from torch import Tensor

from .geo import GeoDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive
Expand Down Expand Up @@ -178,7 +177,7 @@ def _extract(self) -> None:

def plot(
self,
sample: Dict[str, Tensor],
sample: Dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
Expand All @@ -192,7 +191,9 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.3
.. versionchanged:: 0.3
Method now takes a sample dict, not a Tensor. Additionally, possible to
show subplot titles and/or use a custom suptitle.
"""
mask = sample["mask"].squeeze(0)
ncols = 1
Expand Down
49 changes: 49 additions & 0 deletions torchgeo/datasets/landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import abc
from typing import Any, Callable, Dict, Optional, Sequence

import matplotlib.pyplot as plt
from rasterio.crs import CRS

from .geo import RasterDataset
Expand Down Expand Up @@ -78,6 +79,54 @@ def __init__(

super().__init__(root, crs, res, transforms, cache)

def plot(
self,
sample: Dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`RasterDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
Raises:
ValueError: if the RGB bands are not included in ``self.bands``
.. versionchanged:: 0.3
Method now takes a sample dict, not a Tensor. Additionally, possible to
show subplot titles and/or use a custom suptitle.
"""
rgb_indices = []
for band in self.rgb_bands:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")

image = sample["image"][rgb_indices].permute(1, 2, 0).float()

# Stretch to the full range
image = (image - image.min()) / (image.max() - image.min())

fig, ax = plt.subplots(1, 1, figsize=(4, 4))

ax.imshow(image)
ax.axis("off")

if show_titles:
ax.set_title("Image")

if suptitle is not None:
plt.suptitle(suptitle)

return fig


class Landsat1(Landsat):
"""Landsat 1 Multispectral Scanner (MSS)."""
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/naip.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def plot(
a matplotlib Figure with the rendered sample
.. versionchanged:: 0.3
Method now takes a sample dict, not a Tensor. Additionally, possible to
show subplot titles and/or use a custom suptitle.
Method now takes a sample dict, not a Tensor. Additionally, possible to
show subplot titles and/or use a custom suptitle.
"""
image = sample["image"][0:3, :, :].permute(1, 2, 0)

Expand Down
7 changes: 4 additions & 3 deletions torchgeo/datasets/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import matplotlib.pyplot as plt
import torch
from rasterio.crs import CRS
from torch import Tensor

from .geo import RasterDataset

Expand Down Expand Up @@ -104,7 +103,7 @@ def __init__(

def plot(
self,
sample: Dict[str, Tensor],
sample: Dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
Expand All @@ -121,7 +120,9 @@ def plot(
Raises:
ValueError: if the RGB bands are not included in ``self.bands``
.. versionadded:: 0.3
.. versionchanged:: 0.3
Method now takes a sample dict, not a Tensor. Additionally, possible to
show subplot titles and/or use a custom suptitle.
"""
rgb_indices = []
for band in self.RGB_BANDS:
Expand Down

0 comments on commit 7e7443a

Please sign in to comment.