From 7229dd401ecb911f698e2883da6f28f7bbc785b8 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 11:42:34 -0600 Subject: [PATCH 1/9] Fix Landsat non-SR band specification --- tests/datasets/test_landsat.py | 12 +++++++++--- torchgeo/datasets/geo.py | 27 ++++++++++++--------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index bc1ff2c8aea..e82151a05e6 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -8,6 +8,7 @@ import pytest import torch import torch.nn as nn +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS @@ -15,10 +16,15 @@ class TestLandsat8: - @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch) -> Landsat8: + @pytest.fixture( + params=[ + ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"], + ["SR_B4", "SR_B3", "SR_B2", "SR_QA_AEROSOL"], + ] + ) + def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Landsat8: root = os.path.join("tests", "data", "landsat8") - bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] + bands = request.param transforms = nn.Identity() return Landsat8(root, bands=bands, transforms=transforms) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 96ae41e20a3..0ae36fb5199 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -323,6 +323,7 @@ def __init__( super().__init__(transforms) self.root = root + self.bands = bands or self.all_bands self.cache = cache # Populate the dataset index @@ -367,21 +368,6 @@ def __init__( f"No {self.__class__.__name__} data was found in '{root}'" ) - if bands and self.all_bands: - band_indexes = [self.all_bands.index(i) + 1 for i in bands] - self.bands = bands - assert len(band_indexes) == len(self.bands) - elif bands: - msg = ( - f"{self.__class__.__name__} is missing an `all_bands` attribute," - " so `bands` cannot be specified." - ) - raise AssertionError(msg) - else: - band_indexes = None - self.bands = self.all_bands - - self.band_indexes = band_indexes self._crs = cast(CRS, crs) self.res = cast(float, res) @@ -424,6 +410,17 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: data_list.append(self._merge_files(band_filepaths, query)) data = torch.cat(data_list) else: + if self.bands and self.all_bands: + band_indexes = [self.all_bands.index(i) + 1 for i in self.bands] + elif self.bands: + msg = ( + f"{self.__class__.__name__} is missing an `all_bands` attribute," + " so `bands` cannot be specified." + ) + raise AssertionError(msg) + else: + band_indexes = None + data = self._merge_files(filepaths, query, self.band_indexes) sample = {"crs": self.crs, "bbox": query} From 27432b059da711e6094a8830eead5e26d8d77bcb Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 11:59:26 -0600 Subject: [PATCH 2/9] Fix variable reference --- torchgeo/datasets/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 0ae36fb5199..754de646b6b 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -421,7 +421,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: else: band_indexes = None - data = self._merge_files(filepaths, query, self.band_indexes) + data = self._merge_files(filepaths, query, band_indexes) sample = {"crs": self.crs, "bbox": query} if self.is_image: From 8638c0f6c729e77b36d76404154661c58298e76a Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 12:26:22 -0600 Subject: [PATCH 3/9] Fix test when no all_bands --- tests/datasets/test_geo.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index a24dd3653f2..658a38806fa 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -52,6 +52,7 @@ class CustomVectorDataset(VectorDataset): class CustomSentinelDataset(Sentinel2): all_bands: List[str] = [] + separate_files = False class CustomNonGeoDataset(NonGeoDataset): @@ -214,18 +215,21 @@ def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"): RasterDataset(str(tmp_path)) - def test_no_allbands(self) -> None: + def test_no_all_bands(self) -> None: root = os.path.join("tests", "data", "sentinel2") bands = ["B04", "B03", "B02"] transforms = nn.Identity() cache = True + ds = CustomSentinelDataset( + root, bands=bands, transforms=transforms, cache=cache + ) + msg = ( "CustomSentinelDataset is missing an `all_bands` attribute," " so `bands` cannot be specified." ) - with pytest.raises(AssertionError, match=msg): - CustomSentinelDataset(root, bands=bands, transforms=transforms, cache=cache) + ds[ds.index.bounds] class TestVectorDataset: From 28c64c80db839e99e39169c3af3490ba528ce8fb Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 12:50:40 -0600 Subject: [PATCH 4/9] Fail during init instead --- tests/datasets/test_geo.py | 7 ++----- torchgeo/datasets/geo.py | 25 +++++++++++++------------ 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 658a38806fa..158117c9d88 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -220,16 +220,13 @@ def test_no_all_bands(self) -> None: bands = ["B04", "B03", "B02"] transforms = nn.Identity() cache = True - ds = CustomSentinelDataset( - root, bands=bands, transforms=transforms, cache=cache - ) - msg = ( "CustomSentinelDataset is missing an `all_bands` attribute," " so `bands` cannot be specified." ) + with pytest.raises(AssertionError, match=msg): - ds[ds.index.bounds] + CustomSentinelDataset(root, bands=bands, transforms=transforms, cache=cache) class TestVectorDataset: diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 754de646b6b..c06f7c3edbb 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -368,6 +368,18 @@ def __init__( f"No {self.__class__.__name__} data was found in '{root}'" ) + if not self.separate_files: + if self.bands and self.all_bands: + band_indexes = [self.all_bands.index(i) + 1 for i in self.bands] + elif self.bands: + msg = ( + f"{self.__class__.__name__} is missing an `all_bands` attribute," + " so `bands` cannot be specified." + ) + raise AssertionError(msg) + else: + band_indexes = None + self._crs = cast(CRS, crs) self.res = cast(float, res) @@ -410,18 +422,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: data_list.append(self._merge_files(band_filepaths, query)) data = torch.cat(data_list) else: - if self.bands and self.all_bands: - band_indexes = [self.all_bands.index(i) + 1 for i in self.bands] - elif self.bands: - msg = ( - f"{self.__class__.__name__} is missing an `all_bands` attribute," - " so `bands` cannot be specified." - ) - raise AssertionError(msg) - else: - band_indexes = None - - data = self._merge_files(filepaths, query, band_indexes) + data = self._merge_files(filepaths, query, self.band_indexes) sample = {"crs": self.crs, "bbox": query} if self.is_image: From 30b34e6ba716ca06f84700ef2f487ae92bd87d5c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 12:52:42 -0600 Subject: [PATCH 5/9] Store variable --- torchgeo/datasets/geo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index c06f7c3edbb..1b1773f382b 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -370,7 +370,7 @@ def __init__( if not self.separate_files: if self.bands and self.all_bands: - band_indexes = [self.all_bands.index(i) + 1 for i in self.bands] + self.band_indexes = [self.all_bands.index(i) + 1 for i in self.bands] elif self.bands: msg = ( f"{self.__class__.__name__} is missing an `all_bands` attribute," @@ -378,7 +378,7 @@ def __init__( ) raise AssertionError(msg) else: - band_indexes = None + self.band_indexes = None self._crs = cast(CRS, crs) self.res = cast(float, res) From ca505bc016b829e73c94c1d32c685ca419fa38e1 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 12:53:27 -0600 Subject: [PATCH 6/9] all_bands -> default_bands --- torchgeo/datasets/landsat.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 48754d7f1e2..bbe43dcfe07 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -74,7 +74,7 @@ def __init__( Raises: FileNotFoundError: if no files are found in ``root`` """ - bands = bands or self.all_bands + bands = bands or self.default_bands self.filename_glob = self.filename_glob.format(bands[0]) super().__init__(root, crs, res, bands, transforms, cache) @@ -133,7 +133,7 @@ class Landsat1(Landsat): filename_glob = "LM01_*_{}.*" - all_bands = ["SR_B4", "SR_B5", "SR_B6", "SR_B7"] + default_bands = ["SR_B4", "SR_B5", "SR_B6", "SR_B7"] rgb_bands = ["SR_B6", "SR_B5", "SR_B4"] @@ -154,7 +154,7 @@ class Landsat4MSS(Landsat): filename_glob = "LM04_*_{}.*" - all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4"] + default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4"] rgb_bands = ["SR_B3", "SR_B2", "SR_B1"] @@ -163,7 +163,7 @@ class Landsat4TM(Landsat): filename_glob = "LT04_*_{}.*" - all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] + default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] rgb_bands = ["SR_B3", "SR_B2", "SR_B1"] @@ -184,7 +184,16 @@ class Landsat7(Landsat): filename_glob = "LE07_*_{}.*" - all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7", "SR_B8"] + default_bands = [ + "SR_B1", + "SR_B2", + "SR_B3", + "SR_B4", + "SR_B5", + "SR_B6", + "SR_B7", + "SR_B8", + ] rgb_bands = ["SR_B3", "SR_B2", "SR_B1"] @@ -193,7 +202,7 @@ class Landsat8(Landsat): filename_glob = "LC08_*_{}.*" - all_bands = [ + default_bands = [ "SR_B1", "SR_B2", "SR_B3", From da09b64098af6b4d6850ada79760d7e40c23af07 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 13:13:40 -0600 Subject: [PATCH 7/9] Simplify logic --- torchgeo/datasets/geo.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 1b1773f382b..aede7e6a405 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -369,16 +369,18 @@ def __init__( ) if not self.separate_files: - if self.bands and self.all_bands: - self.band_indexes = [self.all_bands.index(i) + 1 for i in self.bands] - elif self.bands: - msg = ( - f"{self.__class__.__name__} is missing an `all_bands` attribute," - " so `bands` cannot be specified." - ) - raise AssertionError(msg) - else: - self.band_indexes = None + self.band_indexes = None + if self.bands: + if self.all_bands: + self.band_indexes = [ + self.all_bands.index(i) + 1 for i in self.bands + ] + else: + msg = ( + f"{self.__class__.__name__} is missing an `all_bands` " + "attribute, so `bands` cannot be specified." + ) + raise AssertionError(msg) self._crs = cast(CRS, crs) self.res = cast(float, res) From f76fd192c14f2f7bf777a03d5ede043e744805b8 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 14:55:09 -0600 Subject: [PATCH 8/9] fix mypy --- torchgeo/datasets/landsat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index bbe43dcfe07..d0d36480f54 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -4,7 +4,7 @@ """Landsat datasets.""" import abc -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -47,6 +47,7 @@ class Landsat(RasterDataset, abc.ABC): \. """ + default_bands: List[str] = [] separate_files = True def __init__( From b37cc99e463b1f9ca2b7ff9ae04f787df2fbc13a Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 22 Feb 2023 14:57:13 -0600 Subject: [PATCH 9/9] Make default_bands required --- torchgeo/datasets/landsat.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index d0d36480f54..c1bd6630a52 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -47,9 +47,13 @@ class Landsat(RasterDataset, abc.ABC): \. """ - default_bands: List[str] = [] separate_files = True + @property + @abc.abstractmethod + def default_bands(self) -> List[str]: + """Bands to load by default.""" + def __init__( self, root: str = "data",