Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RasterDataset: fix band indexing bug #1135

Merged
merged 9 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class CustomVectorDataset(VectorDataset):

class CustomSentinelDataset(Sentinel2):
all_bands: List[str] = []
separate_files = False


class CustomNonGeoDataset(NonGeoDataset):
Expand Down Expand Up @@ -214,7 +215,7 @@ def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
RasterDataset(str(tmp_path))

def test_no_allbands(self) -> None:
def test_no_all_bands(self) -> None:
root = os.path.join("tests", "data", "sentinel2")
bands = ["B04", "B03", "B02"]
transforms = nn.Identity()
Expand Down
12 changes: 9 additions & 3 deletions tests/datasets/test_landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import BoundingBox, IntersectionDataset, Landsat8, UnionDataset


class TestLandsat8:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch) -> Landsat8:
@pytest.fixture(
params=[
["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"],
["SR_B4", "SR_B3", "SR_B2", "SR_QA_AEROSOL"],
]
)
def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Landsat8:
root = os.path.join("tests", "data", "landsat8")
bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"]
bands = request.param
transforms = nn.Identity()
return Landsat8(root, bands=bands, transforms=transforms)

Expand Down
28 changes: 14 additions & 14 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def __init__(
super().__init__(transforms)

self.root = root
self.bands = bands or self.all_bands
self.cache = cache

# Populate the dataset index
Expand Down Expand Up @@ -367,21 +368,20 @@ def __init__(
f"No {self.__class__.__name__} data was found in '{root}'"
)

if bands and self.all_bands:
band_indexes = [self.all_bands.index(i) + 1 for i in bands]
self.bands = bands
assert len(band_indexes) == len(self.bands)
elif bands:
msg = (
f"{self.__class__.__name__} is missing an `all_bands` attribute,"
" so `bands` cannot be specified."
)
raise AssertionError(msg)
else:
band_indexes = None
self.bands = self.all_bands
if not self.separate_files:
self.band_indexes = None
if self.bands:
if self.all_bands:
self.band_indexes = [
self.all_bands.index(i) + 1 for i in self.bands
]
else:
msg = (
f"{self.__class__.__name__} is missing an `all_bands` "
"attribute, so `bands` cannot be specified."
)
raise AssertionError(msg)

self.band_indexes = band_indexes
self._crs = cast(CRS, crs)
self.res = cast(float, res)

Expand Down
28 changes: 21 additions & 7 deletions torchgeo/datasets/landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Landsat datasets."""

import abc
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence

import matplotlib.pyplot as plt
from rasterio.crs import CRS
Expand Down Expand Up @@ -49,6 +49,11 @@ class Landsat(RasterDataset, abc.ABC):

separate_files = True

@property
@abc.abstractmethod
def default_bands(self) -> List[str]:
"""Bands to load by default."""

def __init__(
self,
root: str = "data",
Expand All @@ -74,7 +79,7 @@ def __init__(
Raises:
FileNotFoundError: if no files are found in ``root``
"""
bands = bands or self.all_bands
bands = bands or self.default_bands
self.filename_glob = self.filename_glob.format(bands[0])

super().__init__(root, crs, res, bands, transforms, cache)
Expand Down Expand Up @@ -133,7 +138,7 @@ class Landsat1(Landsat):

filename_glob = "LM01_*_{}.*"

all_bands = ["SR_B4", "SR_B5", "SR_B6", "SR_B7"]
default_bands = ["SR_B4", "SR_B5", "SR_B6", "SR_B7"]
rgb_bands = ["SR_B6", "SR_B5", "SR_B4"]


Expand All @@ -154,7 +159,7 @@ class Landsat4MSS(Landsat):

filename_glob = "LM04_*_{}.*"

all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4"]
default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4"]
rgb_bands = ["SR_B3", "SR_B2", "SR_B1"]


Expand All @@ -163,7 +168,7 @@ class Landsat4TM(Landsat):

filename_glob = "LT04_*_{}.*"

all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"]
default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"]
rgb_bands = ["SR_B3", "SR_B2", "SR_B1"]


Expand All @@ -184,7 +189,16 @@ class Landsat7(Landsat):

filename_glob = "LE07_*_{}.*"

all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7", "SR_B8"]
default_bands = [
"SR_B1",
"SR_B2",
"SR_B3",
"SR_B4",
"SR_B5",
"SR_B6",
"SR_B7",
"SR_B8",
]
rgb_bands = ["SR_B3", "SR_B2", "SR_B1"]


Expand All @@ -193,7 +207,7 @@ class Landsat8(Landsat):

filename_glob = "LC08_*_{}.*"

all_bands = [
default_bands = [
"SR_B1",
"SR_B2",
"SR_B3",
Expand Down