Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL4EO-L: add additional metadata #2535

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/datasets/test_ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def test_getitem(self, dataset: SSL4EOL) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert (
x['image'].size(0)
== dataset.seasons * dataset.metadata[dataset.split]['num_bands']
assert x['image'].size(0) == dataset.seasons * len(
dataset.metadata[dataset.split]['all_bands']
)

def test_len(self, dataset: SSL4EOL) -> None:
Expand Down
59 changes: 57 additions & 2 deletions torchgeo/datasets/landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import abc
from collections.abc import Callable, Iterable, Sequence
from typing import Any
from typing import Any, ClassVar

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
Expand Down Expand Up @@ -58,6 +58,12 @@ class Landsat(RasterDataset, abc.ABC):
def default_bands(self) -> tuple[str, ...]:
"""Bands to load by default."""

# https://www.usgs.gov/faqs/what-are-band-designations-landsat-satellites
@property
@abc.abstractmethod
def wavelengths(self) -> dict[str, float]:
"""Central wavelength (μm)."""

def __init__(
self,
paths: Path | Iterable[Path] = 'data',
Expand Down Expand Up @@ -148,6 +154,13 @@ class Landsat1(Landsat):
default_bands = ('B4', 'B5', 'B6', 'B7')
rgb_bands = ('B6', 'B5', 'B4')

wavelengths: ClassVar[dict[str, float]] = {
'B4': (0.5 + 0.6) / 2,
'B5': (0.6 + 0.7) / 2,
'B6': (0.7 + 0.8) / 2,
'B7': (0.8 + 1.1) / 2,
}


class Landsat2(Landsat1):
"""Landsat 2 Multispectral Scanner (MSS)."""
Expand All @@ -169,6 +182,13 @@ class Landsat4MSS(Landsat):
default_bands = ('B1', 'B2', 'B3', 'B4')
rgb_bands = ('B3', 'B2', 'B1')

wavelengths: ClassVar[dict[str, float]] = {
'B1': (0.5 + 0.6) / 2,
'B2': (0.6 + 0.7) / 2,
'B3': (0.7 + 0.8) / 2,
'B4': (0.8 + 1.1) / 2,
}


class Landsat4TM(Landsat):
"""Landsat 4 Thematic Mapper (TM)."""
Expand All @@ -178,6 +198,16 @@ class Landsat4TM(Landsat):
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')

wavelengths: ClassVar[dict[str, float]] = {
'B1': (0.45 + 0.52) / 2,
'B2': (0.52 + 0.60) / 2,
'B3': (0.63 + 0.69) / 2,
'B4': (0.76 + 0.90) / 2,
'B5': (1.55 + 1.75) / 2,
'B6': (10.40 + 12.50) / 2,
'B7': (2.08 + 2.35) / 2,
}


class Landsat5MSS(Landsat4MSS):
"""Landsat 4 Multispectral Scanner (MSS)."""
Expand All @@ -196,9 +226,20 @@ class Landsat7(Landsat):

filename_glob = 'LE07_*_{}.*'

default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7')
rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')

wavelengths: ClassVar[dict[str, float]] = {
'B1': (0.45 + 0.52) / 2,
'B2': (0.52 + 0.60) / 2,
'B3': (0.63 + 0.69) / 2,
'B4': (0.77 + 0.90) / 2,
'B5': (1.55 + 1.75) / 2,
'B6': (10.40 + 12.50) / 2,
'B7': (2.09 + 2.35) / 2,
'B8': (0.52 + 0.90) / 2,
}


class Landsat8(Landsat):
"""Landsat 8 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS)."""
Expand All @@ -208,6 +249,20 @@ class Landsat8(Landsat):
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
rgb_bands = ('SR_B4', 'SR_B3', 'SR_B2')

wavelengths: ClassVar[dict[str, float]] = {
'B1': (0.43 + 0.45) / 2,
'B2': (0.45 + 0.51) / 2,
'B3': (0.53 + 0.59) / 2,
'B4': (0.64 + 0.67) / 2,
'B5': (0.85 + 0.88) / 2,
'B6': (1.57 + 1.65) / 2,
'B7': (2.11 + 2.29) / 2,
'B8': (0.50 + 0.68) / 2,
'B9': (1.36 + 1.38) / 2,
'B10': (10.6 + 11.19) / 2,
'B11': (11.50 + 12.51) / 2,
}


class Landsat9(Landsat8):
"""Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2)."""
Expand Down
85 changes: 74 additions & 11 deletions torchgeo/datasets/ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, check_integrity, download_url, extract_archive
from .landsat import Landsat, Landsat5TM, Landsat7, Landsat8
from .utils import (
Path,
check_integrity,
disambiguate_timestamp,
download_url,
extract_archive,
)


class SSL4EO(NonGeoDataset):
Expand All @@ -32,7 +39,7 @@ class SSL4EO(NonGeoDataset):
"""


class SSL4EOL(NonGeoDataset):
class SSL4EOL(SSL4EO):
"""SSL4EO-L dataset.

Landsat version of SSL4EO.
Expand Down Expand Up @@ -96,15 +103,42 @@ class SSL4EOL(NonGeoDataset):
"""

class _Metadata(TypedDict):
num_bands: int
all_bands: list[str]
rgb_bands: list[int]

metadata: ClassVar[dict[str, _Metadata]] = {
'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]},
'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]},
'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]},
'oli_tirs_toa': {'num_bands': 11, 'rgb_bands': [3, 2, 1]},
'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]},
'tm_toa': {
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'],
'rgb_bands': [2, 1, 0],
},
'etm_toa': {
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B6', 'B7', 'B8'],
'rgb_bands': [2, 1, 0],
},
'etm_sr': {
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B7'],
'rgb_bands': [2, 1, 0],
},
'oli_tirs_toa': {
'all_bands': [
'B1',
'B2',
'B3',
'B4',
'B5',
'B6',
'B7',
'B8',
'B9',
'B10',
'B11',
],
'rgb_bands': [3, 2, 1],
},
'oli_sr': {
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'],
'rgb_bands': [3, 2, 1],
},
}

url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}'
Expand Down Expand Up @@ -197,6 +231,17 @@ def __init__(

self._verify()

if split.startswith('tm'):
base: type[Landsat] = Landsat5TM
elif split.startswith('etm'):
base = Landsat7
else:
base = Landsat8

self.wavelengths = []
for band in self.metadata[split]['all_bands']:
self.wavelengths.append(base.wavelengths[band])

self.scenes = sorted(os.listdir(self.subdir))

def __getitem__(self, index: int) -> dict[str, Tensor]:
Expand All @@ -213,14 +258,32 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
subdirs = random.sample(subdirs, self.seasons)

images = []
xs = []
ys = []
ts = []
wavelengths = []
for subdir in subdirs:
mint, maxt = disambiguate_timestamp(subdir[-8:], Landsat.date_format)
directory = os.path.join(root, subdir)
filename = os.path.join(directory, 'all_bands.tif')
with rasterio.open(filename) as f:
minx, maxx = f.bounds.left, f.bounds.right
miny, maxy = f.bounds.bottom, f.bounds.top
image = f.read()
images.append(torch.from_numpy(image.astype(np.float32)))

sample = {'image': torch.cat(images)}
xs.append((minx + maxx) / 2)
ys.append((miny + maxy) / 2)
ts.append((mint + maxt) / 2)
wavelengths.extend(self.wavelengths)

sample = {
'image': torch.cat(images),
'x': torch.tensor(xs),
'y': torch.tensor(ys),
't': torch.tensor(ts),
'wavelength': torch.tensor(wavelengths),
'res': torch.tensor(30),
}

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down Expand Up @@ -302,7 +365,7 @@ def plot(
fig, axes = plt.subplots(
ncols=self.seasons, squeeze=False, figsize=(4 * self.seasons, 4)
)
num_bands = self.metadata[self.split]['num_bands']
num_bands = len(self.metadata[self.split]['all_bands'])
rgb_bands = self.metadata[self.split]['rgb_bands']

for i in range(self.seasons):
Expand Down
Loading