diff --git a/docs/conf.py b/docs/conf.py index 55a15db9118..ec721bfff9d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath('..')) -import torchgeo # noqa: E402 +import torchgeo # -- Project information ----------------------------------------------------- diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index 54107f4f8d0..99a981f9d2b 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -369,8 +369,8 @@ " date_format = '%Y%m%dT%H%M%S'\n", " is_image = True\n", " separate_files = True\n", - " all_bands = ['B02', 'B03', 'B04', 'B08']\n", - " rgb_bands = ['B04', 'B03', 'B02']" + " all_bands = ('B02', 'B03', 'B04', 'B08')\n", + " rgb_bands = ('B04', 'B03', 'B02')" ] }, { @@ -432,8 +432,8 @@ " date_format = '%Y%m%dT%H%M%S'\n", " is_image = True\n", " separate_files = True\n", - " all_bands = ['B02', 'B03', 'B04', 'B08']\n", - " rgb_bands = ['B04', 'B03', 'B02']\n", + " all_bands = ('B02', 'B03', 'B04', 'B08')\n", + " rgb_bands = ('B04', 'B03', 'B02')\n", "\n", " def plot(self, sample):\n", " # Find the correct band index order\n", diff --git a/experiments/ssl4eo/download_ssl4eo.py b/experiments/ssl4eo/download_ssl4eo.py index d93413bdd25..2708346482f 100755 --- a/experiments/ssl4eo/download_ssl4eo.py +++ b/experiments/ssl4eo/download_ssl4eo.py @@ -125,7 +125,7 @@ def filter_collection( if filtered.size().getInfo() == 0: raise ee.EEException( - f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' # noqa: E501 + f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' ) return filtered diff --git a/experiments/ssl4eo/sample_ssl4eo.py b/experiments/ssl4eo/sample_ssl4eo.py index 68d1056df55..69f82f10283 100755 --- a/experiments/ssl4eo/sample_ssl4eo.py +++ b/experiments/ssl4eo/sample_ssl4eo.py @@ -47,7 +47,7 @@ def get_world_cities( download_root: str = 'world_cities', size: int = 10000 ) -> pd.DataFrame: - url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' # noqa: E501 + url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' filename = 'worldcities.csv' download_and_extract_archive(url, download_root) cols = ['city', 'lat', 'lng', 'population'] diff --git a/pyproject.toml b/pyproject.toml index 646e77ff41a..ccc94cd766c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -268,7 +268,7 @@ quote-style = "single" skip-magic-trailing-comma = true [tool.ruff.lint] -extend-select = ["ANN", "D", "I", "NPY201", "UP"] +extend-select = ["ANN", "D", "I", "NPY201", "RUF", "UP"] ignore = ["ANN101", "ANN102", "ANN401"] [tool.ruff.lint.per-file-ignores] diff --git a/tests/data/dfc2022/data.py b/tests/data/dfc2022/data.py index 39d41f5d945..67323262452 100755 --- a/tests/data/dfc2022/data.py +++ b/tests/data/dfc2022/data.py @@ -19,36 +19,36 @@ train_set = [ { - 'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 - 'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', # noqa: E501 + 'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', + 'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', + 'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', }, { - 'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 - 'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', # noqa: E501 + 'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', + 'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', + 'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', }, ] unlabeled_set = [ { - 'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', + 'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', }, { - 'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', + 'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', }, ] val_set = [ { - 'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', + 'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', }, { - 'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', + 'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', }, ] diff --git a/tests/data/seasonet/data.py b/tests/data/seasonet/data.py index 68fa8ffe397..e3197ddde12 100644 --- a/tests/data/seasonet/data.py +++ b/tests/data/seasonet/data.py @@ -112,7 +112,7 @@ # Compute checksums with open(archive, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f'{season}: {repr(md5)}') + print(f'{season}: {md5!r}') # Write meta.csv with open('meta.csv', 'w') as f: @@ -121,7 +121,7 @@ # Compute checksums with open('meta.csv', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f'meta.csv: {repr(md5)}') + print(f'meta.csv: {md5!r}') os.makedirs('splits', exist_ok=True) @@ -138,4 +138,4 @@ # Compute checksums with open('splits.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f'splits: {repr(md5)}') + print(f'splits: {md5!r}') diff --git a/tests/datasets/test_eurocrops.py b/tests/datasets/test_eurocrops.py index 5354d325fed..3b2d4fc63f7 100644 --- a/tests/datasets/test_eurocrops.py +++ b/tests/datasets/test_eurocrops.py @@ -83,5 +83,5 @@ def test_invalid_query(self, dataset: EuroCrops) -> None: dataset[query] def test_integrity_error(self, dataset: EuroCrops) -> None: - dataset.zenodo_files = [('AA.zip', 'invalid')] + dataset.zenodo_files = (('AA.zip', 'invalid'),) assert not dataset._check_integrity() diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 28c50303b5e..8ff30531fee 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -72,7 +72,7 @@ class CustomVectorDataset(VectorDataset): class CustomSentinelDataset(Sentinel2): - all_bands: list[str] = [] + all_bands: tuple[str, ...] = () separate_files = False @@ -356,7 +356,7 @@ def test_no_data(self, tmp_path: Path) -> None: def test_no_all_bands(self) -> None: root = os.path.join('tests', 'data', 'sentinel2') - bands = ['B04', 'B03', 'B02'] + bands = ('B04', 'B03', 'B02') transforms = nn.Identity() cache = True msg = ( diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index a0b2b61add7..b0c13b6075b 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -73,7 +73,7 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) @pytest.fixture def weights(self) -> WeightsEnum: diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 40c4ee3d0d2..be8132c808b 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -103,13 +103,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass @@ -259,13 +259,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 035bdacc260..742cda3c371 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -97,13 +97,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass diff --git a/tests/trainers/test_iobench.py b/tests/trainers/test_iobench.py index f67d19582ac..0fbde73bdc8 100644 --- a/tests/trainers/test_iobench.py +++ b/tests/trainers/test_iobench.py @@ -27,12 +27,12 @@ def test_trainer(self, name: str, fast_dev_run: bool) -> None: '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass diff --git a/tests/trainers/test_moco.py b/tests/trainers/test_moco.py index 24b9fdb49cd..32c002dc573 100644 --- a/tests/trainers/test_moco.py +++ b/tests/trainers/test_moco.py @@ -63,7 +63,7 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) def test_version_warnings(self) -> None: with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'): diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 349fa6ab7d1..00c9da65321 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -84,13 +84,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass @@ -237,13 +237,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 5f096da75e7..ea4b0646521 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -108,13 +108,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py index eacf2c0ed70..7e1292ab7c0 100644 --- a/tests/trainers/test_simclr.py +++ b/tests/trainers/test_simclr.py @@ -63,7 +63,7 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) def test_version_warnings(self) -> None: with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'): diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index f1ed2346164..1160f037366 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -37,7 +37,7 @@ def __init__( seasons = kwargs.get('seasons', 1) # Normalization only available for RGB dataset, defined here: - # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501 + # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py if bands == SeasonalContrastS2.rgb_bands: _min = torch.tensor([3, 2, 0]) _max = torch.tensor([88, 103, 129]) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 64701cef519..99126698ffa 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -3,7 +3,7 @@ """So2Sat datamodule.""" -from typing import Any +from typing import Any, ClassVar import torch from torch import Generator, Tensor @@ -21,7 +21,7 @@ class So2SatDataModule(NonGeoDataModule): "train" set and use the "test" set as the test set. """ - means_per_version: dict[str, Tensor] = { + means_per_version: ClassVar[dict[str, Tensor]] = { '2': torch.tensor( [ -0.00003591224260, @@ -91,7 +91,7 @@ class So2SatDataModule(NonGeoDataModule): } means_per_version['3_culture_10'] = means_per_version['2'] - stds_per_version: dict[str, Tensor] = { + stds_per_version: ClassVar[dict[str, Tensor]] = { '2': torch.tensor( [ 0.17555201, diff --git a/torchgeo/datamodules/ssl4eo.py b/torchgeo/datamodules/ssl4eo.py index 6ad558dcf87..f0b1ecdee46 100644 --- a/torchgeo/datamodules/ssl4eo.py +++ b/torchgeo/datamodules/ssl4eo.py @@ -45,7 +45,7 @@ class SSL4EOS12DataModule(NonGeoDataModule): .. versionadded:: 0.5 """ - # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 + # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 mean = torch.tensor(0) std = torch.tensor(10000) diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index 6dcd690c735..b46695957be 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -63,14 +63,14 @@ class ADVANCE(NonGeoDataset): * `scipy `_ to load the audio files to tensors """ - urls = [ + urls = ( 'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1', 'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1', - ] - filenames = ['ADVANCE_vision.zip', 'ADVANCE_sound.zip'] - md5s = ['a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31'] - directories = ['vision', 'sound'] - classes = [ + ) + filenames = ('ADVANCE_vision.zip', 'ADVANCE_sound.zip') + md5s = ('a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31') + directories = ('vision', 'sound') + classes: tuple[str, ...] = ( 'airport', 'beach', 'bridge', @@ -84,7 +84,7 @@ class ADVANCE(NonGeoDataset): 'sparse shrub land', 'sports land', 'train station', - ] + ) def __init__( self, @@ -119,7 +119,7 @@ def __init__( raise DatasetNotFoundError(self) self.files = self._load_files(self.root) - self.classes = sorted({f['cls'] for f in self.files}) + self.classes = tuple(sorted({f['cls'] for f in self.files})) self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} def __getitem__(self, index: int) -> dict[str, Tensor]: diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index 1b80e11555b..1ceaa9c9d3d 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -46,7 +46,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): is_image = False - url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' # noqa: E501 + url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson' diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index c1eae6de222..e40fca5eecf 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -7,7 +7,7 @@ import pathlib import re from collections.abc import Callable, Iterable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import torch @@ -90,8 +90,8 @@ class AgriFieldNet(RasterDataset): _(?PB[0-9A-Z]{2})_10m """ - rgb_bands = ['B04', 'B03', 'B02'] - all_bands = [ + rgb_bands = ('B04', 'B03', 'B02') + all_bands = ( 'B01', 'B02', 'B03', @@ -104,9 +104,9 @@ class AgriFieldNet(RasterDataset): 'B09', 'B11', 'B12', - ] + ) - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 255), 1: (255, 211, 0, 255), 2: (255, 37, 37, 255), diff --git a/torchgeo/datasets/airphen.py b/torchgeo/datasets/airphen.py index 3b0caf607ed..12b8c38141c 100644 --- a/torchgeo/datasets/airphen.py +++ b/torchgeo/datasets/airphen.py @@ -40,8 +40,8 @@ class Airphen(RasterDataset): # Each camera measures a custom set of spectral bands chosen at purchase time. # Hiphen offers 8 bands to choose from, sorted from short to long wavelength. - all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8'] - rgb_bands = ['B4', 'B3', 'B1'] + all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8') + rgb_bands = ('B4', 'B3', 'B1') def plot( self, diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 70bd76373d5..4dd1ae927de 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -147,7 +147,7 @@ class BeninSmallHolderCashews(NonGeoDataset): ) rgb_bands = ('B04', 'B03', 'B02') - classes = [ + classes = ( 'No data', 'Well-managed planatation', 'Poorly-managed planatation', @@ -155,7 +155,7 @@ class BeninSmallHolderCashews(NonGeoDataset): 'Residential', 'Background', 'Uncertain', - ] + ) # Same for all tiles tile_height = 1186 @@ -199,11 +199,13 @@ def __init__( # Calculate the indices that we will use over all tiles self.chips_metadata = [] - for y in list(range(0, self.tile_height - self.chip_size, stride)) + [ - self.tile_height - self.chip_size + for y in [ + *list(range(0, self.tile_height - self.chip_size, stride)), + self.tile_height - self.chip_size, ]: - for x in list(range(0, self.tile_width - self.chip_size, stride)) + [ - self.tile_width - self.chip_size + for x in [ + *list(range(0, self.tile_width - self.chip_size, stride)), + self.tile_width - self.chip_size, ]: self.chips_metadata.append((y, x)) diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 0f4c94565e1..38669cd6ff1 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -7,6 +7,7 @@ import json import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -124,9 +125,9 @@ class BigEarthNet(NonGeoDataset): * https://doi.org/10.1109/IGARSS.2019.8900532 - """ # noqa: E501 + """ - class_sets = { + class_sets: ClassVar[dict[int, list[str]]] = { 19: [ 'Urban fabric', 'Industrial or commercial units', @@ -197,7 +198,7 @@ class BigEarthNet(NonGeoDataset): ], } - label_converter = { + label_converter: ClassVar[dict[int, int]] = { 0: 0, 1: 0, 2: 1, @@ -232,24 +233,24 @@ class BigEarthNet(NonGeoDataset): 42: 18, } - splits_metadata = { + splits_metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { - 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false', # noqa: E501 + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false', 'filename': 'bigearthnet-train.csv', 'md5': '623e501b38ab7b12fe44f0083c00986d', }, 'val': { - 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false', # noqa: E501 + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false', 'filename': 'bigearthnet-val.csv', 'md5': '22efe8ed9cbd71fa10742ff7df2b7978', }, 'test': { - 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false', # noqa: E501 + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false', 'filename': 'bigearthnet-test.csv', 'md5': '697fb90677e30571b9ac7699b7e5b432', }, } - metadata = { + metadata: ClassVar[dict[str, dict[str, str]]] = { 's1': { 'url': 'https://zenodo.org/records/12687186/files/BigEarthNet-S1-v1.0.tar.gz', 'md5': '94ced73440dea8c7b9645ee738c5a172', diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index dc757b96c4b..5d3fe4b631f 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -50,7 +50,7 @@ class BioMassters(NonGeoDataset): .. versionadded:: 0.5 """ - valid_splits = ['train', 'test'] + valid_splits = ('train', 'test') valid_sensors = ('S1', 'S2') metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv' diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index f91993763dd..ec6afdb34b7 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -30,7 +30,7 @@ class CanadianBuildingFootprints(VectorDataset): # https://github.com/microsoft/CanadianBuildingFootprints/issues/11 url = 'https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/' - provinces_territories = [ + provinces_territories = ( 'Alberta', 'BritishColumbia', 'Manitoba', @@ -44,8 +44,8 @@ class CanadianBuildingFootprints(VectorDataset): 'Quebec', 'Saskatchewan', 'YukonTerritory', - ] - md5s = [ + ) + md5s = ( '8b4190424e57bb0902bd8ecb95a9235b', 'fea05d6eb0006710729c675de63db839', 'adf11187362624d68f9c69aaa693c46f', @@ -59,7 +59,7 @@ class CanadianBuildingFootprints(VectorDataset): '9ff4417ae00354d39a0cf193c8df592c', 'a51078d8e60082c7d3a3818240da6dd5', 'c11f3bd914ecabd7cac2cb2871ec0261', - ] + ) def __init__( self, diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 25e42d2b030..07bd8193b79 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -6,7 +6,7 @@ import os import pathlib from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import torch @@ -38,7 +38,7 @@ class CDL(RasterDataset): If you use this dataset in your research, please cite it using the following format: * https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0 - """ # noqa: E501 + """ filename_glob = '*_30m_cdls.tif' filename_regex = r""" @@ -49,8 +49,8 @@ class CDL(RasterDataset): date_format = '%Y' is_image = False - url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip' # noqa: E501 - md5s = { + url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip' + md5s: ClassVar[dict[int, str]] = { 2023: '8c7685d6278d50c554f934b16a6076b7', 2022: '754cf50670cdfee511937554785de3e6', 2021: '27606eab08fe975aa138baad3e5dfcd8', @@ -69,7 +69,7 @@ class CDL(RasterDataset): 2008: '0610f2f17ab60a9fbb3baeb7543993a4', } - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 255), 1: (255, 211, 0, 255), 2: (255, 37, 37, 255), diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py index 61eefbdf35c..ba773607a54 100644 --- a/torchgeo/datasets/chabud.py +++ b/torchgeo/datasets/chabud.py @@ -4,7 +4,8 @@ """ChaBuD dataset.""" import os -from collections.abc import Callable +from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -53,7 +54,7 @@ class ChaBuD(NonGeoDataset): .. versionadded:: 0.6 """ - all_bands = [ + all_bands = ( 'B01', 'B02', 'B03', @@ -66,10 +67,10 @@ class ChaBuD(NonGeoDataset): 'B09', 'B11', 'B12', - ] - rgb_bands = ['B04', 'B03', 'B02'] - folds = {'train': [1, 2, 3, 4], 'val': [0]} - url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5' # noqa: E501 + ) + rgb_bands = ('B04', 'B03', 'B02') + folds: ClassVar[dict[str, list[int]]] = {'train': [1, 2, 3, 4], 'val': [0]} + url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5' filename = 'train_eval.hdf5' md5 = '15d78fb825f9a81dad600db828d22c08' @@ -77,7 +78,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - bands: list[str] = all_bands, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index c3dd0ae88d0..46dbaefc806 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -9,7 +9,7 @@ import sys from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import fiona import matplotlib.pyplot as plt @@ -39,7 +39,7 @@ class Chesapeake(RasterDataset, ABC): The Chesapeake Bay Land Use and Land Cover Database (LULC) facilitates characterization of the landscape and land change for and between discrete time - periods. The database was developed by the University of Vermont’s Spatial Analysis + periods. The database was developed by the University of Vermont's Spatial Analysis Laboratory in cooperation with Chesapeake Conservancy (CC) and U.S. Geological Survey (USGS) as part of a 6-year Cooperative Agreement between Chesapeake Conservancy and the U.S. Environmental Protection Agency (EPA) and a separate @@ -83,7 +83,7 @@ def state(self) -> str: """State abbreviation.""" return self.__class__.__name__[-2:].lower() - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 11: (0, 92, 230, 255), 12: (0, 92, 230, 255), 13: (0, 92, 230, 255), @@ -255,7 +255,7 @@ def plot( class ChesapeakeDC(Chesapeake): """This subset of the dataset contains data only for Washington, D.C.""" - md5s = { + md5s: ClassVar[dict[int, str]] = { 2013: '9f1df21afbb9d5c0fcf33af7f6750a7f', 2017: 'c45e4af2950e1c93ecd47b61af296d9b', } @@ -264,7 +264,7 @@ class ChesapeakeDC(Chesapeake): class ChesapeakeDE(Chesapeake): """This subset of the dataset contains data only for Delaware.""" - md5s = { + md5s: ClassVar[dict[int, str]] = { 2013: '5850d96d897babba85610658aeb5951a', 2018: 'ee94c8efeae423d898677104117bdebc', } @@ -273,7 +273,7 @@ class ChesapeakeDE(Chesapeake): class ChesapeakeMD(Chesapeake): """This subset of the dataset contains data only for Maryland.""" - md5s = { + md5s: ClassVar[dict[int, str]] = { 2013: '9c3ca5040668d15284c1bd64b7d6c7a0', 2018: '0647530edf8bec6e60f82760dcc7db9c', } @@ -282,7 +282,7 @@ class ChesapeakeMD(Chesapeake): class ChesapeakeNY(Chesapeake): """This subset of the dataset contains data only for New York.""" - md5s = { + md5s: ClassVar[dict[int, str]] = { 2013: '38a29b721610ba661a7f8b6ec71a48b7', 2017: '4c1b1a50fd9368cd7b8b12c4d80c63f3', } @@ -291,7 +291,7 @@ class ChesapeakeNY(Chesapeake): class ChesapeakePA(Chesapeake): """This subset of the dataset contains data only for Pennsylvania.""" - md5s = { + md5s: ClassVar[dict[int, str]] = { 2013: '86febd603a120a49ef7d23ef486152a3', 2017: 'b11d92e4471e8cb887c790d488a338c1', } @@ -300,7 +300,7 @@ class ChesapeakePA(Chesapeake): class ChesapeakeVA(Chesapeake): """This subset of the dataset contains data only for Virginia.""" - md5s = { + md5s: ClassVar[dict[int, str]] = { 2014: '49c9700c71854eebd00de24d8488eb7c', 2018: '51731c8b5632978bfd1df869ea10db5b', } @@ -309,7 +309,7 @@ class ChesapeakeVA(Chesapeake): class ChesapeakeWV(Chesapeake): """This subset of the dataset contains data only for West Virginia.""" - md5s = { + md5s: ClassVar[dict[int, str]] = { 2014: '32fea42fae147bd58a83e3ea6cccfb94', 2018: '80f25dcba72e39685ab33215c5d97292', } @@ -337,16 +337,16 @@ class ChesapeakeCVPR(GeoDataset): * https://doi.org/10.1109/cvpr.2019.01301 """ - subdatasets = ['base', 'prior_extension'] - urls = { - 'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip', # noqa: E501 - 'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1', # noqa: E501 + subdatasets = ('base', 'prior_extension') + urls: ClassVar[dict[str, str]] = { + 'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip', + 'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1', } - filenames = { + filenames: ClassVar[dict[str, str]] = { 'base': 'cvpr_chesapeake_landcover.zip', 'prior_extension': 'cvpr_chesapeake_landcover_prior_extension.zip', } - md5s = { + md5s: ClassVar[dict[str, str]] = { 'base': '1225ccbb9590e9396875f221e5031514', 'prior_extension': '402f41d07823c8faf7ea6960d7c4e17a', } @@ -354,7 +354,7 @@ class ChesapeakeCVPR(GeoDataset): crs = CRS.from_epsg(3857) res = 1 - lc_cmap = { + lc_cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 0), 1: (0, 197, 255, 255), 2: (38, 115, 0, 255), @@ -374,7 +374,7 @@ class ChesapeakeCVPR(GeoDataset): ] ) - valid_layers = [ + valid_layers = ( 'naip-new', 'naip-old', 'landsat-leaf-on', @@ -383,8 +383,8 @@ class ChesapeakeCVPR(GeoDataset): 'lc', 'buildings', 'prior_from_cooccurrences_101_31_no_osm_no_buildings', - ] - states = ['de', 'md', 'va', 'wv', 'pa', 'ny'] + ) + states = ('de', 'md', 'va', 'wv', 'pa', 'ny') splits = ( [f'{state}-train' for state in states] + [f'{state}-val' for state in states] @@ -392,7 +392,7 @@ class ChesapeakeCVPR(GeoDataset): ) # these are used to check the integrity of the dataset - _files = [ + _files = ( 'de_1m_2013_extended-debuffered-test_tiles', 'de_1m_2013_extended-debuffered-train_tiles', 'de_1m_2013_extended-debuffered-val_tiles', @@ -412,18 +412,18 @@ class ChesapeakeCVPR(GeoDataset): 'wv_1m_2014_extended-debuffered-train_tiles', 'wv_1m_2014_extended-debuffered-val_tiles', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif', - 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif', # noqa: E501 - 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif', # noqa: E501 + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif', + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif', - 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501 + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', 'spatial_index.geojson', - ] + ) p_src_crs = pyproj.CRS('epsg:3857') - p_transformers = { + p_transformers: ClassVar[dict[str, CRS]] = { 'epsg:26917': pyproj.Transformer.from_crs( p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True ).transform, @@ -511,7 +511,7 @@ def __init__( 'lc': row['properties']['lc'], 'nlcd': row['properties']['nlcd'], 'buildings': row['properties']['buildings'], - 'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn, # noqa: E501 + 'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn, }, ) diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index 7c7ed8b630c..e0ca0045e33 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -55,9 +56,9 @@ class CloudCoverDetection(NonGeoDataset): """ url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_cloud_cover_detection_challenge_v1/final' - all_bands = ['B02', 'B03', 'B04', 'B08'] - rgb_bands = ['B04', 'B03', 'B02'] - splits = {'train': 'public', 'test': 'private'} + all_bands = ('B02', 'B03', 'B04', 'B08') + rgb_bands = ('B04', 'B03', 'B02') + splits: ClassVar[dict[str, str]] = {'train': 'public', 'test': 'private'} def __init__( self, diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index 91ddbf8c54a..61d5c4acafd 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -42,7 +42,7 @@ class CMSGlobalMangroveCanopy(RasterDataset): zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip' md5 = '3e7f9f23bf971c25e828b36e6c5496e3' - all_countries = [ + all_countries = ( 'AndamanAndNicobar', 'Angola', 'Anguilla', @@ -164,9 +164,9 @@ class CMSGlobalMangroveCanopy(RasterDataset): 'VirginIslandsUs', 'WallisAndFutuna', 'Yemen', - ] + ) - measurements = ['agb', 'hba95', 'hmax95'] + measurements = ('agb', 'hba95', 'hmax95') def __init__( self, diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index cae82e597fc..fa97fa87037 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -50,12 +50,12 @@ def base_url(self) -> str: @property @abc.abstractmethod - def filenames(self) -> list[str]: + def filenames(self) -> tuple[str, ...]: """List of files to download.""" @property @abc.abstractmethod - def md5s(self) -> list[str]: + def md5s(self) -> tuple[str, ...]: """List of MD5 checksums of files to download.""" @property @@ -239,7 +239,7 @@ class COWCCounting(COWC): base_url = ( 'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/' ) - filenames = [ + filenames = ( 'COWC_train_list_64_class.txt.bz2', 'COWC_test_list_64_class.txt.bz2', 'COWC_Counting_Toronto_ISPRS.tbz', @@ -248,8 +248,8 @@ class COWCCounting(COWC): 'COWC_Counting_Vaihingen_ISPRS.tbz', 'COWC_Counting_Columbus_CSUAV_AFRL.tbz', 'COWC_Counting_Utah_AGRC.tbz', - ] - md5s = [ + ) + md5s = ( '187543d20fa6d591b8da51136e8ef8fb', '930cfd6e160a7b36db03146282178807', 'bc2613196dfa93e66d324ae43e7c1fdb', @@ -258,7 +258,7 @@ class COWCCounting(COWC): '4009c1e420566390746f5b4db02afdb9', 'daf8033c4e8ceebbf2c3cac3fabb8b10', '777ec107ed2a3d54597a739ce74f95ad', - ] + ) filename = 'COWC_{}_list_64_class.txt' @@ -268,7 +268,7 @@ class COWCDetection(COWC): base_url = ( 'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/' ) - filenames = [ + filenames = ( 'COWC_train_list_detection.txt.bz2', 'COWC_test_list_detection.txt.bz2', 'COWC_Detection_Toronto_ISPRS.tbz', @@ -277,8 +277,8 @@ class COWCDetection(COWC): 'COWC_Detection_Vaihingen_ISPRS.tbz', 'COWC_Detection_Columbus_CSUAV_AFRL.tbz', 'COWC_Detection_Utah_AGRC.tbz', - ] - md5s = [ + ) + md5s = ( 'c954a5a3dac08c220b10cfbeec83893c', 'c6c2d0a78f12a2ad88b286b724a57c1a', '11af24f43b198b0f13c8e94814008a48', @@ -287,7 +287,7 @@ class COWCDetection(COWC): '23945d5b22455450a938382ccc2a8b27', 'f40522dc97bea41b10117d4a5b946a6f', '195da7c9443a939a468c9f232fd86ee3', - ] + ) filename = 'COWC_{}_list_detection.txt' diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py index 30f3a43f634..bb3e4b3f3c5 100644 --- a/torchgeo/datasets/cropharvest.py +++ b/torchgeo/datasets/cropharvest.py @@ -7,6 +7,7 @@ import json import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -55,7 +56,7 @@ class CropHarvest(NonGeoDataset): """ # https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py - all_bands = [ + all_bands = ( 'VV', 'VH', 'B2', @@ -74,12 +75,12 @@ class CropHarvest(NonGeoDataset): 'elevation', 'slope', 'NDVI', - ] - rgb_bands = ['B4', 'B3', 'B2'] + ) + rgb_bands = ('B4', 'B3', 'B2') features_url = 'https://zenodo.org/records/7257688/files/features.tar.gz?download=1' labels_url = 'https://zenodo.org/records/7257688/files/labels.geojson?download=1' - file_dict = { + file_dict: ClassVar[dict[str, dict[str, str]]] = { 'features': { 'url': features_url, 'filename': 'features.tar.gz', diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index 4e262d5266f..2248dab4292 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -65,8 +65,8 @@ class CV4AKenyaCropType(NonGeoDataset): """ url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge' - tiles = list(map(str, range(4))) - dates = [ + tiles = tuple(map(str, range(4))) + dates = ( '20190606', '20190701', '20190706', @@ -80,7 +80,7 @@ class CV4AKenyaCropType(NonGeoDataset): '20190924', '20191004', '20191103', - ] + ) all_bands = ( 'B01', 'B02', @@ -96,7 +96,7 @@ class CV4AKenyaCropType(NonGeoDataset): 'B12', 'CLD', ) - rgb_bands = ['B04', 'B03', 'B02'] + rgb_bands = ('B04', 'B03', 'B02') # Same for all tiles tile_height = 3035 @@ -141,11 +141,13 @@ def __init__( # Calculate the indices that we will use over all tiles self.chips_metadata = [] for tile_index in range(len(self.tiles)): - for y in list(range(0, self.tile_height - self.chip_size, stride)) + [ - self.tile_height - self.chip_size + for y in [ + *list(range(0, self.tile_height - self.chip_size, stride)), + self.tile_height - self.chip_size, ]: - for x in list(range(0, self.tile_width - self.chip_size, stride)) + [ - self.tile_width - self.chip_size + for x in [ + *list(range(0, self.tile_width - self.chip_size, stride)), + self.tile_width - self.chip_size, ]: self.chips_metadata.append((tile_index, y, x)) diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index fcd9fb7bac2..51b82f9ff92 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -74,13 +74,13 @@ class DeepGlobeLandCover(NonGeoDataset): $ unzip deepglobe2018-landcover-segmentation-traindataset.zip .. versionadded:: 0.3 - """ # noqa: E501 + """ filename = 'data.zip' data_root = 'data' md5 = 'f32684b0b2bf6f8d604cd359a399c061' - splits = ['train', 'test'] - classes = [ + splits = ('train', 'test') + classes = ( 'Urban land', 'Agriculture land', 'Rangeland', @@ -88,8 +88,8 @@ class DeepGlobeLandCover(NonGeoDataset): 'Water', 'Barren land', 'Unknown', - ] - colormap = [ + ) + colormap = ( (0, 255, 255), (255, 255, 0), (255, 0, 255), @@ -97,7 +97,7 @@ class DeepGlobeLandCover(NonGeoDataset): (0, 0, 255), (255, 255, 255), (0, 0, 0), - ] + ) def __init__( self, @@ -246,12 +246,15 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample['image'], sample['mask'], alpha=alpha, colors=self.colormap + sample['image'], sample['mask'], alpha=alpha, colors=list(self.colormap) ) if 'prediction' in sample: ncols += 1 image2 = draw_semantic_segmentation_masks( - sample['image'], sample['prediction'], alpha=alpha, colors=self.colormap + sample['image'], + sample['prediction'], + alpha=alpha, + colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index 46d96e87d3f..edc79d8ae2b 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -75,9 +76,9 @@ class DFC2022(NonGeoDataset): * https://doi.org/10.1007/s10994-020-05943-y .. versionadded:: 0.3 - """ # noqa: E501 + """ - classes = [ + classes = ( 'No information', 'Urban fabric', 'Industrial, commercial, public, military, private and transport units', @@ -94,8 +95,8 @@ class DFC2022(NonGeoDataset): 'Wetlands', 'Water', 'Clouds and Shadows', - ] - colormap = [ + ) + colormap = ( '#231F20', '#DB5F57', '#DB9757', @@ -112,8 +113,8 @@ class DFC2022(NonGeoDataset): '#579BDB', '#0062FF', '#231F20', - ] - metadata = { + ) + metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'filename': 'labeled_train.zip', 'md5': '2e87d6a218e466dd0566797d7298c7a9', diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 1d71a8940e6..5e6d4b80d99 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -6,7 +6,7 @@ import os import sys from collections.abc import Callable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import fiona import matplotlib.pyplot as plt @@ -54,9 +54,9 @@ class EnviroAtlas(GeoDataset): crs = CRS.from_epsg(3857) res = 1 - valid_prior_layers = ['prior', 'prior_no_osm_no_buildings'] + valid_prior_layers = ('prior', 'prior_no_osm_no_buildings') - valid_layers = [ + valid_layers = ( 'naip', 'nlcd', 'roads', @@ -65,14 +65,15 @@ class EnviroAtlas(GeoDataset): 'waterbodies', 'buildings', 'lc', - ] + valid_prior_layers + *valid_prior_layers, + ) - cities = [ + cities = ( 'pittsburgh_pa-2010_1m', 'durham_nc-2012_1m', 'austin_tx-2012_1m', 'phoenix_az-2010_1m', - ] + ) splits = ( [f'{state}-train' for state in cities[:1]] + [f'{state}-val' for state in cities[:1]] @@ -81,7 +82,7 @@ class EnviroAtlas(GeoDataset): ) # these are used to check the integrity of the dataset - _files = [ + _files = ( 'austin_tx-2012_1m-test_tiles-debuffered', 'austin_tx-2012_1m-val5_tiles-debuffered', 'durham_nc-2012_1m-test_tiles-debuffered', @@ -100,13 +101,13 @@ class EnviroAtlas(GeoDataset): 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif', 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif', 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif', - 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif', # noqa: E501 - 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501 + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', 'spatial_index.geojson', - ] + ) p_src_crs = pyproj.CRS('epsg:3857') - p_transformers = { + p_transformers: ClassVar[dict[str, CRS]] = { 'epsg:26917': pyproj.Transformer.from_crs( p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True ).transform, @@ -222,7 +223,7 @@ class EnviroAtlas(GeoDataset): dtype=np.uint8, ) - highres_classes = [ + highres_classes = ( 'Unclassified', 'Water', 'Impervious Surface', @@ -234,7 +235,7 @@ class EnviroAtlas(GeoDataset): 'Orchards', 'Woody Wetlands', 'Emergent Wetlands', - ] + ) highres_cmap = ListedColormap( [ [1.00000000, 1.00000000, 1.00000000], diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index ebf1d91f70b..7855c8bb3cf 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -56,9 +57,9 @@ class ETCI2021(NonGeoDataset): the ETCI competition. """ - bands = ['VV', 'VH'] - masks = ['flood', 'water_body'] - metadata = { + bands = ('VV', 'VH') + masks = ('flood', 'water_body') + metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'filename': 'train.zip', 'md5': '1e95792fe0f6e3c9000abdeab2a8ab0f', diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index 63a6f916526..1d65a55f60c 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -7,7 +7,7 @@ import os import pathlib from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -53,7 +53,7 @@ class EUDEM(RasterDataset): zipfile_glob = 'eu_dem_v11_*[A-Z0-9].zip' filename_regex = '(?P[eudem_v11]{10})_(?P[A-Z0-9]{6})' - md5s = { + md5s: ClassVar[dict[str, str]] = { 'eu_dem_v11_E00N20.zip': '96edc7e11bc299b994e848050d6be591', 'eu_dem_v11_E10N00.zip': 'e14be147ac83eddf655f4833d55c1571', 'eu_dem_v11_E10N10.zip': '2eb5187e4d827245b33768404529c709', diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index bf0f173d4c6..cb5bb2a5bc5 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -61,7 +61,7 @@ class EuroCrops(VectorDataset): date_format = '%Y' # Filename and md5 of files in this dataset on zenodo. - zenodo_files = [ + zenodo_files: tuple[tuple[str, str], ...] = ( ('AT_2021.zip', '490241df2e3d62812e572049fc0c36c5'), ('BE_VLG_2021.zip', 'ac4b9e12ad39b1cba47fdff1a786c2d7'), ('DE_LS_2021.zip', '6d94e663a3ff7988b32cb36ea24a724f'), @@ -81,7 +81,7 @@ class EuroCrops(VectorDataset): # Year is unknown for Romania portion (ny = no year). # We skip since it is inconsistent with the rest of the data. # ("RO_ny.zip", "648e1504097765b4b7f825decc838882"), - ] + ) def __init__( self, diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 8292257404c..26c1c860cf8 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable, Sequence -from typing import cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset): * https://ieeexplore.ieee.org/document/8519248 """ - url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' filename = 'EuroSATallBands.zip' md5 = '5ac12b3b2557aa56e1826e981e8e200e' @@ -63,13 +63,13 @@ class EuroSAT(NonGeoClassificationDataset): 'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif' ) - splits = ['train', 'val', 'test'] - split_urls = { - 'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', # noqa: E501 - 'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', # noqa: E501 - 'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', # noqa: E501 + splits = ('train', 'val', 'test') + split_urls: ClassVar[dict[str, str]] = { + 'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', + 'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', + 'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': '908f142e73d6acdf3f482c5e80d851b1', 'val': '95de90f2aa998f70a3b2416bfe0687b4', 'test': '7ae5ab94471417b6e315763121e67c5f', @@ -93,7 +93,10 @@ class EuroSAT(NonGeoClassificationDataset): rgb_bands = ('B04', 'B03', 'B02') - BAND_SETS = {'all': all_band_names, 'rgb': rgb_bands} + BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = { + 'all': all_band_names, + 'rgb': rgb_bands, + } def __init__( self, @@ -302,12 +305,12 @@ class EuroSATSpatial(EuroSAT): .. versionadded:: 0.6 """ - split_urls = { + split_urls: ClassVar[dict[str, str]] = { 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt', 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt', 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': '7be3254be39f23ce4d4d144290c93292', 'val': 'acf392290050bb3df790dc8fc0ebf193', 'test': '5ec1733f9c16116bf0aa2d921fc613ef', @@ -325,16 +328,16 @@ class EuroSAT100(EuroSAT): .. versionadded:: 0.5 """ - url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' filename = 'EuroSAT100.zip' md5 = 'c21c649ba747e86eda813407ef17d596' - split_urls = { - 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', # noqa: E501 - 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', # noqa: E501 - 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', # noqa: E501 + split_urls: ClassVar[dict[str, str]] = { + 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', + 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', + 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': '033d0c23e3a75e3fa79618b0e35fe1c7', 'val': '3e3f8b3c344182b8d126c4cc88f3f215', 'test': 'f908f151b950f270ad18e61153579794', diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index 6d4337f57bc..d58968eaa19 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import Any, cast +from typing import Any, ClassVar, cast from xml.etree.ElementTree import Element, parse import matplotlib.patches as patches @@ -119,7 +119,7 @@ class FAIR1M(NonGeoDataset): .. versionadded:: 0.2 """ - classes = { + classes: ClassVar[dict[str, dict[str, Any]]] = { 'Passenger Ship': {'id': 0, 'category': 'Ship'}, 'Motorboat': {'id': 1, 'category': 'Ship'}, 'Fishing Boat': {'id': 2, 'category': 'Ship'}, @@ -159,12 +159,12 @@ class FAIR1M(NonGeoDataset): 'Bridge': {'id': 36, 'category': 'Road'}, } - filename_glob = { + filename_glob: ClassVar[dict[str, str]] = { 'train': os.path.join('train', '**', 'images', '*.tif'), 'val': os.path.join('validation', 'images', '*.tif'), 'test': os.path.join('test', 'images', '*.tif'), } - directories = { + directories: ClassVar[dict[str, tuple[str, ...]]] = { 'train': ( os.path.join('train', 'part1', 'images'), os.path.join('train', 'part1', 'labelXml'), @@ -175,9 +175,9 @@ class FAIR1M(NonGeoDataset): os.path.join('validation', 'images'), os.path.join('validation', 'labelXml'), ), - 'test': (os.path.join('test', 'images')), + 'test': (os.path.join('test', 'images'),), } - paths = { + paths: ClassVar[dict[str, tuple[str, ...]]] = { 'train': ( os.path.join('train', 'part1', 'images.zip'), os.path.join('train', 'part1', 'labelXml.zip'), @@ -194,7 +194,7 @@ class FAIR1M(NonGeoDataset): os.path.join('test', 'images2.zip'), ), } - urls = { + urls: ClassVar[dict[str, tuple[str, ...]]] = { 'train': ( 'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf', 'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u', @@ -211,7 +211,7 @@ class FAIR1M(NonGeoDataset): 'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0', ), } - md5s = { + md5s: ClassVar[dict[str, tuple[str, ...]]] = { 'train': ( 'a460fe6b1b5b276bf856ce9ac72d6568', '80f833ff355f91445c92a0c0c1fa7414', diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py index be40dfcf6d6..9370488f503 100644 --- a/torchgeo/datasets/fire_risk.py +++ b/torchgeo/datasets/fire_risk.py @@ -55,8 +55,8 @@ class FireRisk(NonGeoClassificationDataset): md5 = 'a77b9a100d51167992ae8c51d26198a6' filename = 'FireRisk.zip' directory = 'FireRisk' - splits = ['train', 'val'] - classes = [ + splits = ('train', 'val') + classes = ( 'High', 'Low', 'Moderate', @@ -64,7 +64,7 @@ class FireRisk(NonGeoClassificationDataset): 'Very_High', 'Very_Low', 'Water', - ] + ) def __init__( self, diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 0eca812bb8b..33333b956ba 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -96,11 +96,8 @@ class ForestDamage(NonGeoDataset): .. versionadded:: 0.3 """ - classes = ['other', 'H', 'LD', 'HD'] - url = ( - 'https://lilablobssc.blob.core.windows.net/larch-casebearer/' - 'Data_Set_Larch_Casebearer.zip' - ) + classes = ('other', 'H', 'LD', 'HD') + url = 'https://lilablobssc.blob.core.windows.net/larch-casebearer/Data_Set_Larch_Casebearer.zip' data_dir = 'Data_Set_Larch_Casebearer' md5 = '907815bcc739bff89496fac8f8ce63d7' diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index f7362e9c692..68a7d853969 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -13,7 +13,7 @@ import sys import warnings from collections.abc import Callable, Iterable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import fiona import fiona.transform @@ -370,13 +370,13 @@ class RasterDataset(GeoDataset): separate_files = False #: Names of all available bands in the dataset - all_bands: list[str] = [] + all_bands: tuple[str, ...] = () #: Names of RGB bands in the dataset, used for plotting - rgb_bands: list[str] = [] + rgb_bands: tuple[str, ...] = () #: Color map for the dataset, used for plotting - cmap: dict[int, tuple[int, int, int, int]] = {} + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {} @property def dtype(self) -> torch.dtype: @@ -458,7 +458,7 @@ def __init__( # See if file has a color map if len(self.cmap) == 0: try: - self.cmap = src.colormap(1) + self.cmap = src.colormap(1) # type: ignore[misc] except ValueError: pass diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 078b83e6054..b42e6e58df6 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -66,8 +66,8 @@ class GID15(NonGeoDataset): md5 = '615682bf659c3ed981826c6122c10c83' filename = 'gid-15.zip' directory = 'GID' - splits = ['train', 'val', 'test'] - classes = [ + splits = ('train', 'val', 'test') + classes = ( 'background', 'industrial_land', 'urban_residential', @@ -84,7 +84,7 @@ class GID15(NonGeoDataset): 'river', 'lake', 'pond', - ] + ) def __init__( self, diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index 007466738d2..e117bf361c6 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -7,7 +7,7 @@ import os import pathlib from collections.abc import Callable, Iterable -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import torch @@ -73,9 +73,9 @@ class GlobBiomass(RasterDataset): is_image = False dtype = torch.float32 # pixelwise regression - measurements = ['agb', 'gsv'] + measurements = ('agb', 'gsv') - md5s = { + md5s: ClassVar[dict[str, str]] = { 'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6', 'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79', 'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097', diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 67a9db80e87..6157579eb1b 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import Any, cast, overload +from typing import Any, ClassVar, cast, overload import fiona import matplotlib.pyplot as plt @@ -100,7 +100,7 @@ class IDTReeS(NonGeoDataset): .. versionadded:: 0.2 """ - classes = { + classes: ClassVar[dict[str, str]] = { 'ACPE': 'Acer pensylvanicum L.', 'ACRU': 'Acer rubrum L.', 'ACSA3': 'Acer saccharum Marshall', @@ -135,19 +135,22 @@ class IDTReeS(NonGeoDataset): 'ROPS': 'Robinia pseudoacacia L.', 'TSCA': 'Tsuga canadensis (L.) Carriere', } - metadata = { + metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { - 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1', # noqa: E501 + 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1', 'md5': '5ddfa76240b4bb6b4a7861d1d31c299c', 'filename': 'IDTREES_competition_train_v2.zip', }, 'test': { - 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1', # noqa: E501 + 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1', 'md5': 'b108931c84a70f2a38a8234290131c9b', 'filename': 'IDTREES_competition_test_v2.zip', }, } - directories = {'train': ['train'], 'test': ['task1', 'task2']} + directories: ClassVar[dict[str, list[str]]] = { + 'train': ['train'], + 'test': ['task1', 'task2'], + } image_size = (200, 200) def __init__( diff --git a/torchgeo/datasets/iobench.py b/torchgeo/datasets/iobench.py index 80376df9579..608a9ccc17a 100644 --- a/torchgeo/datasets/iobench.py +++ b/torchgeo/datasets/iobench.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Sequence -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -40,9 +40,9 @@ class IOBench(IntersectionDataset): .. versionadded:: 0.6 """ - url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz' - md5s = { + md5s: ClassVar[dict[str, str]] = { 'original': 'e3a908a0fd1c05c1af2f4c65724d59b3', 'raw': 'e9603990441007ce7bba73bb8ba7d217', 'preprocessed': '9801f1240b238cb17525c865e413d1fd', @@ -54,7 +54,7 @@ def __init__( split: str = 'preprocessed', crs: CRS | None = None, res: float | None = None, - bands: Sequence[str] | None = Landsat9.default_bands + ['SR_QA_AEROSOL'], + bands: Sequence[str] | None = [*Landsat9.default_bands, 'SR_QA_AEROSOL'], classes: list[int] = [0], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index 8adc3842c45..25830eaf3dc 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -8,7 +8,7 @@ import pathlib import re from collections.abc import Callable, Iterable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import torch @@ -43,8 +43,8 @@ class L7IrishImage(RasterDataset): """ date_format = '%Y%m%d' is_image = True - rgb_bands = ['B30', 'B20', 'B10'] - all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80'] + rgb_bands = ('B30', 'B20', 'B10') + all_bands = ('B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80') class L7IrishMask(RasterDataset): @@ -59,7 +59,7 @@ class L7IrishMask(RasterDataset): _newmask2015\.TIF$ """ is_image = False - classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud'] + classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud') ordinal_map = torch.zeros(256, dtype=torch.long) ordinal_map[64] = 1 ordinal_map[128] = 2 @@ -158,11 +158,11 @@ class L7Irish(IntersectionDataset): * https://www.sciencebase.gov/catalog/item/573ccf18e4b0dae0d5e4b109 .. versionadded:: 0.5 - """ # noqa: E501 + """ - url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' - md5s = { + md5s: ClassVar[dict[str, str]] = { 'austral': '0a34770b992a62abeb88819feb192436', 'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082', 'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa', diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index 4865ec932b4..318efa2476a 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -7,7 +7,7 @@ import os import pathlib from collections.abc import Callable, Iterable, Sequence -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import torch @@ -36,8 +36,8 @@ class L8BiomeImage(RasterDataset): """ date_format = '%Y%j' is_image = True - rgb_bands = ['B4', 'B3', 'B2'] - all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11'] + rgb_bands = ('B4', 'B3', 'B2') + all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11') class L8BiomeMask(RasterDataset): @@ -57,7 +57,7 @@ class L8BiomeMask(RasterDataset): """ date_format = '%Y%j' is_image = False - classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud'] + classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud') ordinal_map = torch.zeros(256, dtype=torch.long) ordinal_map[64] = 1 ordinal_map[128] = 2 @@ -116,11 +116,11 @@ class L8Biome(IntersectionDataset): * https://doi.org/10.1016/j.rse.2017.03.026 .. versionadded:: 0.5 - """ # noqa: E501 + """ - url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' - md5s = { + md5s: ClassVar[dict[str, str]] = { 'barren': '0eb691822d03dabd4f5ea8aadd0b41c3', 'forest': '4a5645596f6bb8cea44677f746ec676e', 'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c', diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 5273c03303c..d9a9643b524 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -9,7 +9,7 @@ import os from collections.abc import Callable from functools import lru_cache -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -64,8 +64,8 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip' filename = 'landcover.ai.v1.zip' md5 = '3268c89070e8734b4e91d531c0617e03' - classes = ['Background', 'Building', 'Woodland', 'Water', 'Road'] - cmap = { + classes = ('Background', 'Building', 'Woodland', 'Water', 'Road') + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 0), 1: (97, 74, 74, 255), 2: (38, 115, 0, 255), diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 48647f5d247..8fb33b7c9cc 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -33,7 +33,7 @@ class Landsat(RasterDataset, abc.ABC): * `Surface Temperature `_ * `Surface Reflectance `_ * `U.S. Analysis Ready Data `_ - """ # noqa: E501 + """ # https://www.usgs.gov/landsat-missions/landsat-collection-2 filename_regex = r""" @@ -55,7 +55,7 @@ class Landsat(RasterDataset, abc.ABC): @property @abc.abstractmethod - def default_bands(self) -> list[str]: + def default_bands(self) -> tuple[str, ...]: """Bands to load by default.""" def __init__( @@ -145,8 +145,8 @@ class Landsat1(Landsat): filename_glob = 'LM01_*_{}.*' - default_bands = ['B4', 'B5', 'B6', 'B7'] - rgb_bands = ['B6', 'B5', 'B4'] + default_bands = ('B4', 'B5', 'B6', 'B7') + rgb_bands = ('B6', 'B5', 'B4') class Landsat2(Landsat1): @@ -166,8 +166,8 @@ class Landsat4MSS(Landsat): filename_glob = 'LM04_*_{}.*' - default_bands = ['B1', 'B2', 'B3', 'B4'] - rgb_bands = ['B3', 'B2', 'B1'] + default_bands = ('B1', 'B2', 'B3', 'B4') + rgb_bands = ('B3', 'B2', 'B1') class Landsat4TM(Landsat): @@ -175,8 +175,8 @@ class Landsat4TM(Landsat): filename_glob = 'LT04_*_{}.*' - default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] - rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1'] + default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7') + rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1') class Landsat5MSS(Landsat4MSS): @@ -196,8 +196,8 @@ class Landsat7(Landsat): filename_glob = 'LE07_*_{}.*' - default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] - rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1'] + default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7') + rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1') class Landsat8(Landsat): @@ -205,11 +205,11 @@ class Landsat8(Landsat): filename_glob = 'LC08_*_{}.*' - default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] - rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2'] + default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7') + rgb_bands = ('SR_B4', 'SR_B3', 'SR_B2') class Landsat9(Landsat8): - """Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" # noqa: E501 + """Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" filename_glob = 'LC09_*_{}.*' diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 9dbc68136db..e664365ab2d 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -7,6 +7,7 @@ import glob import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -26,8 +27,8 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC): .. versionadded:: 0.6 """ - splits: list[str] | dict[str, dict[str, str]] - directories = ['A', 'B', 'label'] + splits: ClassVar[tuple[str, ...] | dict[str, dict[str, str]]] + directories = ('A', 'B', 'label') def __init__( self, @@ -237,7 +238,7 @@ class LEVIRCD(LEVIRCDBase): .. versionadded:: 0.6 """ - splits = { + splits: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-', 'filename': 'train.zip', @@ -336,7 +337,7 @@ class LEVIRCDPlus(LEVIRCDBase): md5 = '1adf156f628aa32fb2e8fe6cada16c04' filename = 'LEVIR-CD+.zip' directory = 'LEVIR-CD+' - splits = ['train', 'test'] + splits = ('train', 'test') def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 8c987548f90..0398410a674 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -5,7 +5,8 @@ import glob import os -from collections.abc import Callable +from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -57,10 +58,10 @@ class LoveDA(NonGeoDataset): .. versionadded:: 0.2 """ - scenes = ['urban', 'rural'] - splits = ['train', 'val', 'test'] + scenes = ('urban', 'rural') + splits = ('train', 'val', 'test') - info_dict = { + info_dict: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1', 'filename': 'Train.zip', @@ -78,7 +79,7 @@ class LoveDA(NonGeoDataset): }, } - classes = [ + classes = ( 'background', 'building', 'road', @@ -87,13 +88,13 @@ class LoveDA(NonGeoDataset): 'forest', 'agriculture', 'no-data', - ] + ) def __init__( self, root: Path = 'data', split: str = 'train', - scene: list[str] = ['urban', 'rural'], + scene: Sequence[str] = ['urban', 'rural'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index 66ce23ad70f..51289cf91cd 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -7,6 +7,7 @@ import shutil from collections import defaultdict from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -36,7 +37,7 @@ class MapInWild(NonGeoDataset): different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging Radiometer Suite NightTime Day/Night band. The dataset consists of 8144 - images with the shape of 1920 × 1920 pixels. The images are weakly annotated + images with the shape of 1920 x 1920 pixels. The images are weakly annotated from the World Database of Protected Areas (WDPA). Dataset features: @@ -54,9 +55,9 @@ class MapInWild(NonGeoDataset): .. versionadded:: 0.5 """ - url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' # noqa: E501 + url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' - modality_urls = { + modality_urls: ClassVar[dict[str, set[str]]] = { 'esa_wc': {'esa_wc/ESA_WC.zip'}, 'viirs': {'viirs/VIIRS.zip'}, 'mask': {'mask/mask.zip'}, @@ -72,7 +73,7 @@ class MapInWild(NonGeoDataset): 'split_IDs': {'split_IDs/split_IDs.csv'}, } - md5s = { + md5s: ClassVar[dict[str, str]] = { 'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92', 'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4', 'mask.zip': '87c83a23a73998ad60d448d240b66225', @@ -91,9 +92,12 @@ class MapInWild(NonGeoDataset): 'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f', } - mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)} + mask_cmap: ClassVar[dict[int, tuple[int, int, int]]] = { + 1: (0, 153, 0), + 0: (255, 255, 255), + } - wc_cmap = { + wc_cmap: ClassVar[dict[int, tuple[int, int, int]]] = { 10: (0, 160, 0), 20: (150, 100, 0), 30: (255, 180, 0), diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 1938aefbafc..14e0ba7dc12 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -48,7 +48,7 @@ class MillionAID(NonGeoDataset): .. versionadded:: 0.3 """ - multi_label_categories = [ + multi_label_categories = ( 'agriculture_land', 'airport_area', 'apartment', @@ -122,9 +122,9 @@ class MillionAID(NonGeoDataset): 'wind_turbine', 'woodland', 'works', - ] + ) - multi_class_categories = [ + multi_class_categories = ( 'apartment', 'apron', 'bare_land', @@ -176,17 +176,17 @@ class MillionAID(NonGeoDataset): 'wastewater_plant', 'wind_turbine', 'works', - ] + ) - md5s = { + md5s: ClassVar[dict[str, str]] = { 'train': '1b40503cafa9b0601653ca36cd788852', 'test': '51a63ee3eeb1351889eacff349a983d8', } - filenames = {'train': 'train.zip', 'test': 'test.zip'} + filenames: ClassVar[dict[str, str]] = {'train': 'train.zip', 'test': 'test.zip'} - tasks = ['multi-class', 'multi-label'] - splits = ['train', 'test'] + tasks = ('multi-class', 'multi-label') + splits = ('train', 'test') def __init__( self, diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index d8185782367..326dccd6d72 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -45,8 +45,8 @@ class NAIP(RasterDataset): """ # Plotting - all_bands = ['R', 'G', 'B', 'NIR'] - rgb_bands = ['R', 'G', 'B'] + all_bands = ('R', 'G', 'B', 'NIR') + rgb_bands = ('R', 'G', 'B') def plot( self, diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py index 83163391735..96633b2e35b 100644 --- a/torchgeo/datasets/nccm.py +++ b/torchgeo/datasets/nccm.py @@ -4,7 +4,7 @@ """Northeastern China Crop Map Dataset.""" from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import torch @@ -57,23 +57,23 @@ class NCCM(RasterDataset): date_format = '%Y' is_image = False - urls = { + urls: ClassVar[dict[int, str]] = { 2019: 'https://figshare.com/ndownloader/files/25070540', 2018: 'https://figshare.com/ndownloader/files/25070624', 2017: 'https://figshare.com/ndownloader/files/25070582', } - md5s = { + md5s: ClassVar[dict[int, str]] = { 2019: '0d062bbd42e483fdc8239d22dba7020f', 2018: 'b3bb4894478d10786aa798fb11693ec1', 2017: 'd047fbe4a85341fa6248fd7e0badab6c', } - fnames = { + fnames: ClassVar[dict[int, str]] = { 2019: 'CDL2019_clip.tif', 2018: 'CDL2018_clip1.tif', 2017: 'CDL2017_clip.tif', } - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 255, 0, 255), 1: (255, 0, 0, 255), 2: (255, 255, 0, 255), diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index e7113b6709f..7f4498c76da 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -7,7 +7,7 @@ import os import pathlib from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import torch @@ -67,7 +67,7 @@ class NLCD(RasterDataset): * 2019: https://doi.org/10.5066/P9KZCM54 .. versionadded:: 0.5 - """ # noqa: E501 + """ filename_glob = 'nlcd_*_land_cover_l48_*.img' filename_regex = ( @@ -79,7 +79,7 @@ class NLCD(RasterDataset): url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip' - md5s = { + md5s: ClassVar[dict[int, str]] = { 2001: '538166a4d783204764e3df3b221fc4cd', 2006: '67454e7874a00294adb9442374d0c309', 2011: 'ea524c835d173658eeb6fa3c8e6b917b', @@ -87,7 +87,7 @@ class NLCD(RasterDataset): 2019: '82851c3f8105763b01c83b4a9e6f3961', } - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 0), 11: (70, 107, 159, 255), 12: (209, 222, 248, 255), diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index 9e341b6822d..f19f4c9a2ce 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -9,7 +9,7 @@ import pathlib import sys from collections.abc import Callable, Iterable -from typing import Any, cast +from typing import Any, ClassVar, cast import fiona import fiona.transform @@ -61,7 +61,7 @@ class OpenBuildings(VectorDataset): .. versionadded:: 0.3 """ - md5s = { + md5s: ClassVar[dict[str, str]] = { '025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17', '04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1', '05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670', diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 808ca93ab09..28f7714a7c6 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -50,7 +51,7 @@ class OSCD(NonGeoDataset): .. versionadded:: 0.2 """ - urls = { + urls: ClassVar[dict[str, str]] = { 'Onera Satellite Change Detection dataset - Images.zip': ( 'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download' ), @@ -61,7 +62,7 @@ class OSCD(NonGeoDataset): 'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download' ), } - md5s = { + md5s: ClassVar[dict[str, str]] = { 'Onera Satellite Change Detection dataset - Images.zip': ( 'c50d4a2941da64e03a47ac4dec63d915' ), @@ -75,9 +76,9 @@ class OSCD(NonGeoDataset): zipfile_glob = '*Onera*.zip' filename_glob = '*Onera*' - splits = ['train', 'test'] + splits = ('train', 'test') - colormap = ['blue'] + colormap = ('blue',) all_bands = ( 'B01', @@ -319,7 +320,7 @@ def get_masked(img: Tensor) -> 'np.typing.NDArray[np.uint8]': torch.from_numpy(rgb_img), sample['mask'], alpha=alpha, - colors=self.colormap, + colors=list(self.colormap), ) return array diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index bdad515f66b..430a6330d89 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable, Sequence +from typing import ClassVar import fiona import matplotlib.pyplot as plt @@ -70,7 +71,7 @@ class PASTIS(NonGeoDataset): .. versionadded:: 0.5 """ - classes = [ + classes = ( 'background', # all non-agricultural land 'meadow', 'soft_winter_wheat', @@ -91,8 +92,8 @@ class PASTIS(NonGeoDataset): 'mixed_cereal', 'sorghum', 'void_label', # for parcels mostly outside their patch - ] - cmap = { + ) + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 255), 1: (174, 199, 232, 255), 2: (255, 127, 14, 255), @@ -118,7 +119,7 @@ class PASTIS(NonGeoDataset): filename = 'PASTIS-R.zip' url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1' md5 = '4887513d6c2d2b07fa935d325bd53e09' - prefix = { + prefix: ClassVar[dict[str, str]] = { 's2': os.path.join('DATA_S2', 'S2_'), 's1a': os.path.join('DATA_S1A', 'S1A_'), 's1d': os.path.join('DATA_S1D', 'S1D_'), @@ -232,7 +233,7 @@ def _load_semantic_targets(self, index: int) -> Tensor: Returns: the target mask """ - # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501 + # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # even though the mask file is 3 bands, we just select the first band array = np.load(self.files[index]['semantic'])[0].astype(np.uint8) tensor = torch.from_numpy(array).long() diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 479ca3cf170..51f1ebd0441 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -54,12 +55,12 @@ class Potsdam2D(NonGeoDataset): * https://doi.org/10.5194/isprsannals-I-3-293-2012 .. versionadded:: 0.2 - """ # noqa: E501 + """ - filenames = ['4_Ortho_RGBIR.zip', '5_Labels_all.zip'] - md5s = ['c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db'] + filenames = ('4_Ortho_RGBIR.zip', '5_Labels_all.zip') + md5s = ('c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db') image_root = '4_Ortho_RGBIR' - splits = { + splits: ClassVar[dict[str, list[str]]] = { 'train': [ 'top_potsdam_2_10', 'top_potsdam_2_11', @@ -103,22 +104,22 @@ class Potsdam2D(NonGeoDataset): 'top_potsdam_7_13', ], } - classes = [ + classes = ( 'Clutter/background', 'Impervious surfaces', 'Building', 'Low Vegetation', 'Tree', 'Car', - ] - colormap = [ + ) + colormap = ( (255, 0, 0), (255, 255, 255), (0, 0, 255), (0, 255, 255), (0, 255, 0), (255, 255, 0), - ] + ) def __init__( self, @@ -257,7 +258,7 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap + sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap) ) if 'prediction' in sample: ncols += 1 @@ -265,7 +266,7 @@ def plot( sample['image'][:3], sample['prediction'], alpha=alpha, - colors=self.colormap, + colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 3fd2501dd0f..811d79cff08 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -61,8 +61,12 @@ class QuakeSet(NonGeoDataset): filename = 'earthquakes.h5' url = 'https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5' md5 = '76fc7c76b7ca56f4844d852e175e1560' - splits = {'train': 'train', 'val': 'validation', 'test': 'test'} - classes = ['unaffected_area', 'earthquake_affected_area'] + splits: ClassVar[dict[str, str]] = { + 'train': 'train', + 'val': 'validation', + 'test': 'test', + } + classes = ('unaffected_area', 'earthquake_affected_area') def __init__( self, diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index bd28ab83f5d..930799ed205 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -56,7 +56,7 @@ class ReforesTree(NonGeoDataset): .. versionadded:: 0.3 """ - classes = ['other', 'banana', 'cacao', 'citrus', 'fruit', 'timber'] + classes = ('other', 'banana', 'cacao', 'citrus', 'fruit', 'timber') url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1' md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302' diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index 0c8108a78b9..fd33b634fde 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -98,13 +98,13 @@ class RESISC45(NonGeoClassificationDataset): filename = 'NWPU-RESISC45.zip' directory = 'NWPU-RESISC45' - splits = ['train', 'val', 'test'] - split_urls = { + splits = ('train', 'val', 'test') + split_urls: ClassVar[dict[str, str]] = { 'train': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-train.txt', 'val': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-val.txt', 'test': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': 'b5a4c05a37de15e4ca886696a85c403e', 'val': 'a0770cee4c5ca20b8c32bbd61e114805', 'test': '3dda9e4988b47eb1de9f07993653eb08', diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 510039d8dcd..50459f5da9f 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -55,11 +56,11 @@ class RwandaFieldBoundary(NonGeoDataset): url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition' - splits = {'train': 57, 'test': 13} + splits: ClassVar[dict[str, int]] = {'train': 57, 'test': 13} dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') all_bands = ('B01', 'B02', 'B03', 'B04') rgb_bands = ('B03', 'B02', 'B01') - classes = ['No field-boundary', 'Field-boundary'] + classes = ('No field-boundary', 'Field-boundary') def __init__( self, diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index 6d19fed7fcc..04cac92b36c 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -6,6 +6,7 @@ import os import random from collections.abc import Callable, Collection, Iterable +from typing import ClassVar import matplotlib.patches as mpatches import matplotlib.pyplot as plt @@ -85,51 +86,51 @@ class SeasoNet(NonGeoDataset): .. versionadded:: 0.5 """ - metadata = [ + metadata = ( { 'name': 'spring', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', 'md5': 'de4cdba7b6196aff624073991b187561', }, { 'name': 'summer', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', 'md5': '6a54d4e134d27ae4eb03f180ee100550', }, { 'name': 'fall', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', 'md5': '5f94920fe41a63c6bfbab7295f7d6b95', }, { 'name': 'winter', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', 'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48', }, { 'name': 'snow', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', 'md5': 'e1b300994143f99ebb03f51d6ab1cbe6', }, { 'name': 'splits', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', 'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed', }, { 'name': 'meta.csv', 'ext': '', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', 'md5': '43ea07974936a6bf47d989c32e16afe7', }, - ] - classes = [ + ) + classes = ( 'Continuous urban fabric', 'Discontinuous urban fabric', 'Industrial or commercial units', @@ -163,12 +164,17 @@ class SeasoNet(NonGeoDataset): 'Coastal lagoons', 'Estuaries', 'Sea and ocean', - ] - all_seasons = {'Spring', 'Summer', 'Fall', 'Winter', 'Snow'} + ) + all_seasons = frozenset({'Spring', 'Summer', 'Fall', 'Winter', 'Snow'}) all_bands = ('10m_RGB', '10m_IR', '20m', '60m') - band_nums = {'10m_RGB': 3, '10m_IR': 1, '20m': 6, '60m': 2} - splits = ['train', 'val', 'test'] - cmap = { + band_nums: ClassVar[dict[str, int]] = { + '10m_RGB': 3, + '10m_IR': 1, + '20m': 6, + '60m': 2, + } + splits = ('train', 'val', 'test') + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (230, 000, 77, 255), 1: (255, 000, 000, 255), 2: (204, 77, 242, 255), @@ -331,7 +337,7 @@ def _load_image(self, index: int) -> Tensor: for band in self.bands: with rasterio.open(f'{path}_{band}.tif') as f: array = f.read( - out_shape=[f.count] + list(self.image_size), + out_shape=[f.count, *list(self.image_size)], out_dtype='int32', resampling=Resampling.bilinear, ) diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index 74b8ebba54a..0d1df7f7bbe 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -5,7 +5,8 @@ import os import random -from collections.abc import Callable +from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -37,7 +38,7 @@ class SeasonalContrastS2(NonGeoDataset): * https://arxiv.org/pdf/2103.16607.pdf """ - all_bands = [ + all_bands = ( 'B1', 'B2', 'B3', @@ -50,10 +51,10 @@ class SeasonalContrastS2(NonGeoDataset): 'B9', 'B11', 'B12', - ] - rgb_bands = ['B4', 'B3', 'B2'] + ) + rgb_bands = ('B4', 'B3', 'B2') - metadata = { + metadata: ClassVar[dict[str, dict[str, str]]] = { '100k': { 'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1', 'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0', @@ -73,7 +74,7 @@ def __init__( root: Path = 'data', version: str = '100k', seasons: int = 1, - bands: list[str] = rgb_bands, + bands: Sequence[str] = rgb_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 8b8ee57803c..20183f06421 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -63,9 +64,9 @@ class SEN12MS(NonGeoDataset): or manually downloaded from https://dataserv.ub.tum.de/s/m1474000 and https://github.com/schmitt-muc/SEN12MS/tree/master/splits. This download will likely take several hours. - """ # noqa: E501 + """ - BAND_SETS: dict[str, tuple[str, ...]] = { + BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = { 'all': ( 'VV', 'VH', @@ -120,9 +121,9 @@ class SEN12MS(NonGeoDataset): 'B12', ) - rgb_bands = ['B04', 'B03', 'B02'] + rgb_bands = ('B04', 'B03', 'B02') - filenames = [ + filenames = ( 'ROIs1158_spring_lc.tar.gz', 'ROIs1158_spring_s1.tar.gz', 'ROIs1158_spring_s2.tar.gz', @@ -137,16 +138,16 @@ class SEN12MS(NonGeoDataset): 'ROIs2017_winter_s2.tar.gz', 'train_list.txt', 'test_list.txt', - ] - light_filenames = [ + ) + light_filenames = ( 'ROIs1158_spring', 'ROIs1868_summer', 'ROIs1970_fall', 'ROIs2017_winter', 'train_list.txt', 'test_list.txt', - ] - md5s = [ + ) + md5s = ( '6e2e8fa8b8cba77ddab49fd20ff5c37b', 'fba019bb27a08c1db96b31f718c34d79', 'd58af2c15a16f376eb3308dc9b685af2', @@ -161,7 +162,7 @@ class SEN12MS(NonGeoDataset): '3807545661288dcca312c9c538537b63', '0a68d4e1eb24f128fccdb930000b2546', 'c7faad064001e646445c4c634169484d', - ] + ) def __init__( self, diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 163c771a6b9..3dc5fc31771 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -137,7 +137,7 @@ class Sentinel1(Sentinel): \. """ date_format = '%Y%m%dT%H%M%S' - all_bands = ['HH', 'HV', 'VV', 'VH'] + all_bands = ('HH', 'HV', 'VV', 'VH') separate_files = True def __init__( @@ -277,7 +277,7 @@ class Sentinel2(Sentinel): date_format = '%Y%m%dT%H%M%S' # https://gisgeography.com/sentinel-2-bands-combinations/ - all_bands = [ + all_bands: tuple[str, ...] = ( 'B01', 'B02', 'B03', @@ -291,8 +291,8 @@ class Sentinel2(Sentinel): 'B10', 'B11', 'B12', - ] - rgb_bands = ['B04', 'B03', 'B02'] + ) + rgb_bands = ('B04', 'B03', 'B02') separate_files = True diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 8a8882f3598..a3a842561fa 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import numpy as np @@ -62,8 +62,8 @@ class SKIPPD(NonGeoDataset): .. versionadded:: 0.5 """ - url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' # noqa: E501 - md5 = { + url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' + md5: ClassVar[dict[str, str]] = { 'forecast': 'f4f3509ddcc83a55c433be9db2e51077', 'nowcast': '0000761d403e45bb5f86c21d3c69aa80', } @@ -71,9 +71,9 @@ class SKIPPD(NonGeoDataset): data_file_name = '2017_2019_images_pv_processed_{}.hdf5' zipfile_name = '2017_2019_images_pv_processed_{}.zip' - valid_splits = ['trainval', 'test'] + valid_splits = ('trainval', 'test') - valid_tasks = ['nowcast', 'forecast'] + valid_tasks = ('nowcast', 'forecast') dateformat = '%m/%d/%Y, %H:%M:%S' diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index e90e89b4e34..4840a48e468 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable, Sequence -from typing import cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -103,10 +103,10 @@ class So2Sat(NonGeoDataset): This dataset requires the following additional library to be installed: * ``_ to load the dataset - """ # noqa: E501 + """ - versions = ['2', '3_random', '3_block', '3_culture_10'] - filenames_by_version = { + versions = ('2', '3_random', '3_block', '3_culture_10') + filenames_by_version: ClassVar[dict[str, dict[str, str]]] = { '2': { 'train': 'training.h5', 'validation': 'validation.h5', @@ -119,7 +119,7 @@ class So2Sat(NonGeoDataset): 'test': 'culture_10/testing.h5', }, } - md5s_by_version = { + md5s_by_version: ClassVar[dict[str, dict[str, str]]] = { '2': { 'train': '702bc6a9368ebff4542d791e53469244', 'validation': '71cfa6795de3e22207229d06d6f8775d', @@ -139,7 +139,7 @@ class So2Sat(NonGeoDataset): }, } - classes = [ + classes = ( 'Compact high rise', 'Compact mid rise', 'Compact low rise', @@ -157,7 +157,7 @@ class So2Sat(NonGeoDataset): 'Bare rock or paved', 'Bare soil or sand', 'Water', - ] + ) all_s1_band_names = ( 'S1_B1', @@ -183,9 +183,9 @@ class So2Sat(NonGeoDataset): ) all_band_names = all_s1_band_names + all_s2_band_names - rgb_bands = ['S2_B04', 'S2_B03', 'S2_B02'] + rgb_bands = ('S2_B04', 'S2_B03', 'S2_B02') - BAND_SETS = { + BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = { 'all': all_band_names, 's1': all_s1_band_names, 's2': all_s2_band_names, diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index 688ce166924..7b79f44c8d2 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -6,8 +6,8 @@ import os import pathlib import re -from collections.abc import Callable, Iterable -from typing import Any, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import torch @@ -79,9 +79,9 @@ class SouthAfricaCropType(RasterDataset): _10m """ date_format = '%Y_%m_%d' - rgb_bands = ['B04', 'B03', 'B02'] - s1_bands = ['VH', 'VV'] - s2_bands = [ + rgb_bands = ('B04', 'B03', 'B02') + s1_bands = ('VH', 'VV') + s2_bands = ( 'B01', 'B02', 'B03', @@ -94,9 +94,9 @@ class SouthAfricaCropType(RasterDataset): 'B09', 'B11', 'B12', - ] - all_bands: list[str] = s1_bands + s2_bands - cmap = { + ) + all_bands = s1_bands + s2_bands + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 255), 1: (255, 211, 0, 255), 2: (255, 37, 37, 255), @@ -113,8 +113,8 @@ def __init__( self, paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, - classes: list[int] = list(cmap.keys()), - bands: list[str] = s2_bands, + classes: Sequence[int] = list(cmap.keys()), + bands: Sequence[str] = s2_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, ) -> None: diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 1ee8d7cb79e..3ec4b559472 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -5,7 +5,7 @@ import pathlib from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -47,7 +47,7 @@ class SouthAmericaSoybean(RasterDataset): is_image = False url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif' - md5s = { + md5s: ClassVar[dict[int, str]] = { 2021: 'edff3ada13a1a9910d1fe844d28ae4f', 2020: '0709dec807f576c9707c8c7e183db31', 2019: '441836493bbcd5e123cff579a58f5a4f', diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 0ff02c55da4..21a31657f58 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -8,7 +8,7 @@ import re from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any +from typing import Any, ClassVar import fiona import matplotlib.pyplot as plt @@ -55,9 +55,9 @@ class SpaceNet(NonGeoDataset, ABC): image_glob = '*.tif' mask_glob = '*.geojson' file_regex = r'_img(\d+)\.' - chip_size: dict[str, tuple[int, int]] = {} + chip_size: ClassVar[dict[str, tuple[int, int]]] = {} - cities = { + cities: ClassVar[dict[int, str]] = { 1: 'Rio', 2: 'Vegas', 3: 'Paris', @@ -98,7 +98,7 @@ def valid_images(self) -> dict[str, list[str]]: @property @abstractmethod - def valid_masks(self) -> list[str]: + def valid_masks(self) -> tuple[str, ...]: """List of valid masks.""" def __init__( @@ -426,7 +426,7 @@ class SpaceNet1(SpaceNet): directory_glob = '{product}' dataset_id = 'SN1_buildings' - tarballs = { + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 1: [ 'SN1_buildings_train_AOI_1_Rio_3band.tar.gz', @@ -441,7 +441,7 @@ class SpaceNet1(SpaceNet): ] }, } - md5s = { + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 1: [ '279e334a2120ecac70439ea246174516', @@ -453,10 +453,16 @@ class SpaceNet1(SpaceNet): 1: ['18283d78b21c239bc1831f3bf1d2c996', '732b3a40603b76e80aac84e002e2b3e8'] }, } - valid_aois = {'train': [1], 'test': [1]} - valid_images = {'train': ['3band', '8band'], 'test': ['3band', '8band']} - valid_masks = ['geojson'] - chip_size = {'3band': (406, 439), '8band': (102, 110)} + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [1], 'test': [1]} + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['3band', '8band'], + 'test': ['3band', '8band'], + } + valid_masks = ('geojson',) + chip_size: ClassVar[dict[str, tuple[int, int]]] = { + '3band': (406, 439), + '8band': (102, 110), + } class SpaceNet2(SpaceNet): @@ -522,7 +528,7 @@ class SpaceNet2(SpaceNet): """ dataset_id = 'SN2_buildings' - tarballs = { + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 2: ['SN2_buildings_train_AOI_2_Vegas.tar.gz'], 3: ['SN2_buildings_train_AOI_3_Paris.tar.gz'], @@ -536,7 +542,7 @@ class SpaceNet2(SpaceNet): 5: ['AOI_5_Khartoum_Test_public.tar.gz'], }, } - md5s = { + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 2: ['307da318bc43aaf9481828f92eda9126'], 3: ['4db469e3e4e7bf025368ad730aec0888'], @@ -550,13 +556,16 @@ class SpaceNet2(SpaceNet): 5: ['037d7be10530f0dd1c43d4ef79f3236e'], }, } - valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]} - valid_images = { + valid_aois: ClassVar[dict[str, list[int]]] = { + 'train': [2, 3, 4, 5], + 'test': [2, 3, 4, 5], + } + valid_images: ClassVar[dict[str, list[str]]] = { 'train': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], 'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], } - valid_masks = [os.path.join('geojson', 'buildings')] - chip_size = {'MUL': (163, 163)} + valid_masks = (os.path.join('geojson', 'buildings'),) + chip_size: ClassVar[dict[str, tuple[int, int]]] = {'MUL': (163, 163)} class SpaceNet3(SpaceNet): @@ -624,7 +633,7 @@ class SpaceNet3(SpaceNet): """ dataset_id = 'SN3_roads' - tarballs = { + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 2: [ 'SN3_roads_train_AOI_2_Vegas.tar.gz', @@ -650,7 +659,7 @@ class SpaceNet3(SpaceNet): 5: ['SN3_roads_test_public_AOI_5_Khartoum.tar.gz'], }, } - md5s = { + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 2: ['06317255b5e0c6df2643efd8a50f22ae', '4acf7846ed8121db1319345cfe9fdca9'], 3: ['c13baf88ee10fe47870c303223cabf82', 'abc8199d4c522d3a14328f4f514702ad'], @@ -664,12 +673,15 @@ class SpaceNet3(SpaceNet): 5: ['f367c79fa0fc1d38e63a0fdd065ed957'], }, } - valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]} - valid_images = { + valid_aois: ClassVar[dict[str, list[int]]] = { + 'train': [2, 3, 4, 5], + 'test': [2, 3, 4, 5], + } + valid_images: ClassVar[dict[str, list[str]]] = { 'train': ['MS', 'PS-MS', 'PAN', 'PS-RGB'], 'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], } - valid_masks = ['geojson_roads', 'geojson_roads_speed'] + valid_masks: tuple[str, ...] = ('geojson_roads', 'geojson_roads_speed') class SpaceNet4(SpaceNet): @@ -708,7 +720,7 @@ class SpaceNet4(SpaceNet): directory_glob = os.path.join('**', '{product}') file_regex = r'_(\d+_\d+)\.' dataset_id = 'SN4_buildings' - tarballs = { + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 6: [ 'Atlanta_nadir7_catid_1030010003D22F00.tar.gz', @@ -743,7 +755,7 @@ class SpaceNet4(SpaceNet): }, 'test': {6: ['SN4_buildings_AOI_6_Atlanta_test_public.tar.gz']}, } - md5s = { + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 6: [ 'd41ab6ec087b07e1e046c55d1fa5754b', @@ -778,12 +790,12 @@ class SpaceNet4(SpaceNet): }, 'test': {6: ['0ec3874bfc19aed63b33ac47b039aace']}, } - valid_aois = {'train': [6], 'test': [6]} - valid_images = { + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [6], 'test': [6]} + valid_images: ClassVar[dict[str, list[str]]] = { 'train': ['MS', 'PAN', 'Pan-Sharpen'], 'test': ['MS', 'PAN', 'Pan-Sharpen'], } - valid_masks = [os.path.join('geojson', 'spacenet-buildings')] + valid_masks = (os.path.join('geojson', 'spacenet-buildings'),) class SpaceNet5(SpaceNet3): @@ -850,26 +862,26 @@ class SpaceNet5(SpaceNet3): file_regex = r'_chip(\d+)\.' dataset_id = 'SN5_roads' - tarballs = { + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 7: ['SN5_roads_train_AOI_7_Moscow.tar.gz'], 8: ['SN5_roads_train_AOI_8_Mumbai.tar.gz'], }, 'test': {9: ['SN5_roads_test_public_AOI_9_San_Juan.tar.gz']}, } - md5s = { + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 7: ['03082d01081a6d8df2bc5a9645148d2a'], 8: ['1ee20ba781da6cb7696eef9a95a5bdcc'], }, 'test': {9: ['fc45afef219dfd3a20f2d4fc597f6882']}, } - valid_aois = {'train': [7, 8], 'test': [9]} - valid_images = { + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [7, 8], 'test': [9]} + valid_images: ClassVar[dict[str, list[str]]] = { 'train': ['MS', 'PAN', 'PS-MS', 'PS-RGB'], 'test': ['MS', 'PAN', 'PS-MS', 'PS-RGB'], } - valid_masks = ['geojson_roads_speed'] + valid_masks = ('geojson_roads_speed',) class SpaceNet6(SpaceNet): @@ -937,20 +949,20 @@ class SpaceNet6(SpaceNet): file_regex = r'_tile_(\d+)\.' dataset_id = 'SN6_buildings' - tarballs = { + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': {11: ['SN6_buildings_AOI_11_Rotterdam_train.tar.gz']}, 'test': {11: ['SN6_buildings_AOI_11_Rotterdam_test_public.tar.gz']}, } - md5s = { + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': {11: ['10ca26d2287716e3b6ef0cf0ad9f946e']}, 'test': {11: ['a07823a5e536feeb8bb6b6f0cb43cf05']}, } - valid_aois = {'train': [11], 'test': [11]} - valid_images = { + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [11], 'test': [11]} + valid_images: ClassVar[dict[str, list[str]]] = { 'train': ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'], 'test': ['SAR-Intensity'], } - valid_masks = ['geojson_buildings'] + valid_masks = ('geojson_buildings',) class SpaceNet7(SpaceNet): @@ -958,7 +970,7 @@ class SpaceNet7(SpaceNet): `SpaceNet 7 `_ is a dataset which consist of medium resolution (4.0m) satellite imagery mosaics acquired from - Planet Labs’ Dove constellation between 2017 and 2020. It includes ≈ 24 + Planet Labs' Dove constellation between 2017 and 2020. It includes ≈ 24 images (one per month) covering > 100 unique geographies, and comprises > 40,000 km2 of imagery and exhaustive polygon labels of building footprints therein, totaling over 11M individual annotations. @@ -993,18 +1005,24 @@ class SpaceNet7(SpaceNet): mask_glob = '*_Buildings.geojson' file_regex = r'global_monthly_(\d+.*\d+)' dataset_id = 'SN7_buildings' - tarballs = { + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': {0: ['SN7_buildings_train.tar.gz']}, 'test': {0: ['SN7_buildings_test_public.tar.gz']}, } - md5s = { + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': {0: ['6eda13b9c28f6f5cdf00a7e8e218c1b1']}, 'test': {0: ['b3bde95a0f8f32f3bfeba49464b9bc97']}, } - valid_aois = {'train': [0], 'test': [0]} - valid_images = {'train': ['images', 'images_masked'], 'test': ['images_masked']} - valid_masks = ['labels', 'labels_match', 'labels_match_pix'] - chip_size = {'images': (1024, 1024), 'images_masked': (1024, 1024)} + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]} + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['images', 'images_masked'], + 'test': ['images_masked'], + } + valid_masks = ('labels', 'labels_match', 'labels_match_pix') + chip_size: ClassVar[dict[str, tuple[int, int]]] = { + 'images': (1024, 1024), + 'images_masked': (1024, 1024), + } class SpaceNet8(SpaceNet): @@ -1024,7 +1042,7 @@ class SpaceNet8(SpaceNet): directory_glob = '{product}' file_regex = r'(\d+_\d+_\d+)\.' dataset_id = 'SN8_floods' - tarballs = { + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 0: [ 'Germany_Training_Public.tar.gz', @@ -1033,16 +1051,19 @@ class SpaceNet8(SpaceNet): }, 'test': {0: ['Louisiana-West_Test_Public.tar.gz']}, } - md5s = { + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { 'train': { 0: ['81383a9050b93e8f70c8557d4568e8a2', 'fa40ae3cf6ac212c90073bf93d70bd95'] }, 'test': {0: ['d41d8cd98f00b204e9800998ecf8427e']}, } - valid_aois = {'train': [0], 'test': [0]} - valid_images = { + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]} + valid_images: ClassVar[dict[str, list[str]]] = { 'train': ['PRE-event', 'POST-event'], 'test': ['PRE-event', 'POST-event'], } - valid_masks = ['annotations'] - chip_size = {'PRE-event': (1300, 1300), 'POST-event': (1300, 1300)} + valid_masks = ('annotations',) + chip_size: ClassVar[dict[str, tuple[int, int]]] = { + 'PRE-event': (1300, 1300), + 'POST-event': (1300, 1300), + } diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py index a087cbf68dc..b2840afb865 100644 --- a/torchgeo/datasets/ssl4eo.py +++ b/torchgeo/datasets/ssl4eo.py @@ -7,7 +7,7 @@ import os import random from collections.abc import Callable -from typing import TypedDict +from typing import ClassVar, TypedDict import matplotlib.pyplot as plt import numpy as np @@ -93,13 +93,13 @@ class SSL4EOL(NonGeoDataset): * https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html .. versionadded:: 0.5 - """ # noqa: E501 + """ class _Metadata(TypedDict): num_bands: int rgb_bands: list[int] - metadata: dict[str, _Metadata] = { + metadata: ClassVar[dict[str, _Metadata]] = { 'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]}, 'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]}, 'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]}, @@ -107,8 +107,8 @@ class _Metadata(TypedDict): 'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]}, } - url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' # noqa: E501 - checksums = { + url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' + checksums: ClassVar[dict[str, dict[str, str]]] = { 'tm_toa': { 'aa': '553795b8d73aa253445b1e67c5b81f11', 'ab': 'e9e0739b5171b37d16086cb89ab370e8', @@ -357,7 +357,7 @@ class _Metadata(TypedDict): md5: str bands: list[str] - metadata: dict[str, _Metadata] = { + metadata: ClassVar[dict[str, _Metadata]] = { 's1': { 'filename': 's1.tar.gz', 'md5': '51ee23b33eb0a2f920bda25225072f3a', diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 3f16bf33b07..fc9bb3883d3 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -46,16 +47,16 @@ class SSL4EOLBenchmark(NonGeoDataset): * https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html .. versionadded:: 0.5 - """ # noqa: E501 + """ - url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' - valid_sensors = ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr'] - valid_products = ['cdl', 'nlcd'] - valid_splits = ['train', 'val', 'test'] + valid_sensors = ('tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr') + valid_products = ('cdl', 'nlcd') + valid_splits = ('train', 'val', 'test') image_root = 'ssl4eo_l_{}_benchmark' - img_md5s = { + img_md5s: ClassVar[dict[str, str]] = { 'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15', 'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f', 'etm_sr': '34a24fa89a801654f8d01e054662c8cd', @@ -63,14 +64,14 @@ class SSL4EOLBenchmark(NonGeoDataset): 'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15', } - mask_dir_dict = { + mask_dir_dict: ClassVar[dict[str, str]] = { 'tm_toa': 'ssl4eo_l_tm_{}', 'etm_toa': 'ssl4eo_l_etm_{}', 'etm_sr': 'ssl4eo_l_etm_{}', 'oli_tirs_toa': 'ssl4eo_l_oli_{}', 'oli_sr': 'ssl4eo_l_oli_{}', } - mask_md5s = { + mask_md5s: ClassVar[dict[str, dict[str, str]]] = { 'tm': { 'cdl': '3d676770ffb56c7e222a7192a652a846', 'nlcd': '261149d7614fcfdcb3be368eefa825c7', @@ -85,7 +86,7 @@ class SSL4EOLBenchmark(NonGeoDataset): }, } - year_dict = { + year_dict: ClassVar[dict[str, int]] = { 'tm_toa': 2011, 'etm_toa': 2019, 'etm_sr': 2019, @@ -93,7 +94,7 @@ class SSL4EOLBenchmark(NonGeoDataset): 'oli_sr': 2019, } - rgb_indices = { + rgb_indices: ClassVar[dict[str, list[int]]] = { 'tm_toa': [2, 1, 0], 'etm_toa': [2, 1, 0], 'etm_sr': [2, 1, 0], @@ -101,9 +102,12 @@ class SSL4EOLBenchmark(NonGeoDataset): 'oli_sr': [3, 2, 1], } - split_percentages = [0.7, 0.15, 0.15] + split_percentages = (0.7, 0.15, 0.15) - cmaps = {'nlcd': NLCD.cmap, 'cdl': CDL.cmap} + cmaps: ClassVar[dict[str, dict[int, tuple[int, int, int, int]]]] = { + 'nlcd': NLCD.cmap, + 'cdl': CDL.cmap, + } def __init__( self, diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 604c17bc70e..eec9be57ab3 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -45,17 +45,17 @@ class SustainBenchCropYield(NonGeoDataset): * https://doi.org/10.1609/aaai.v31i1.11172 .. versionadded:: 0.5 - """ # noqa: E501 + """ - valid_countries = ['usa', 'brazil', 'argentina'] + valid_countries = ('usa', 'brazil', 'argentina') md5 = '362bad07b51a1264172b8376b39d1fc9' - url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' # noqa: E501 + url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' dir = 'soybeans' - valid_splits = ['train', 'dev', 'test'] + valid_splits = ('train', 'dev', 'test') def __init__( self, diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 7045ee558a8..5527a7ed133 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -66,19 +66,19 @@ class UCMerced(NonGeoClassificationDataset): * https://dl.acm.org/doi/10.1145/1869790.1869829 """ - url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip' filename = 'UCMerced_LandUse.zip' md5 = '5b7ec56793786b6dc8a908e8854ac0e4' base_dir = os.path.join('UCMerced_LandUse', 'Images') - splits = ['train', 'val', 'test'] - split_urls = { - 'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt', # noqa: E501 - 'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt', # noqa: E501 - 'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt', # noqa: E501 + splits = ('train', 'val', 'test') + split_urls: ClassVar[dict[str, str]] = { + 'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt', + 'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt', + 'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': 'f2fb12eb2210cfb53f93f063a35ff374', 'val': '11ecabfc52782e5ea6a9c7c0d263aca0', 'test': '046aff88472d8fc07c4678d03749e28d', diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 42443c56d2d..db13059bfa7 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -49,12 +50,12 @@ class USAVars(NonGeoDataset): .. versionadded:: 0.3 """ - data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' # noqa: E501 + data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' dirname = 'uar' md5 = '677e89fd20e5dd0fe4d29b61827c2456' - label_urls = { + label_urls: ClassVar[dict[str, str]] = { 'housing': data_url.format('housing.csv'), 'income': data_url.format('income.csv'), 'roads': data_url.format('roads.csv'), @@ -64,7 +65,7 @@ class USAVars(NonGeoDataset): 'treecover': data_url.format('treecover.csv'), } - split_metadata = { + split_metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'url': data_url.format('train_split.txt'), 'filename': 'train_split.txt', @@ -82,7 +83,7 @@ class USAVars(NonGeoDataset): }, } - ALL_LABELS = ['treecover', 'elevation', 'population'] + ALL_LABELS = ('treecover', 'elevation', 'population') def __init__( self, diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index aa8da3a8c82..73861011a7b 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -86,11 +86,11 @@ def __post_init__(self) -> None: # https://github.com/PyCQA/pydocstyle/issues/525 @overload - def __getitem__(self, key: int) -> float: # noqa: D105 + def __getitem__(self, key: int) -> float: pass @overload - def __getitem__(self, key: slice) -> list[float]: # noqa: D105 + def __getitem__(self, key: slice) -> list[float]: pass def __getitem__(self, key: int | slice) -> float | list[float]: @@ -289,7 +289,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess[byt The completed process. """ kwargs['check'] = True - return subprocess.run((self.name,) + args, **kwargs) + return subprocess.run((self.name, *args), **kwargs) def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]: @@ -547,7 +547,7 @@ def draw_semantic_segmentation_masks( def rgb_to_mask( - rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]] + rgb: np.typing.NDArray[np.uint8], colors: Sequence[tuple[int, int, int]] ) -> np.typing.NDArray[np.uint8]: """Converts an RGB colormap mask to a integer mask. diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index 305eb950197..2c671ca27ac 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -55,15 +56,15 @@ class Vaihingen2D(NonGeoDataset): * https://doi.org/10.5194/isprsannals-I-3-293-2012 .. versionadded:: 0.2 - """ # noqa: E501 + """ - filenames = [ + filenames = ( 'ISPRS_semantic_labeling_Vaihingen.zip', 'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip', - ] - md5s = ['462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277'] + ) + md5s = ('462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277') image_root = 'top' - splits = { + splits: ClassVar[dict[str, list[str]]] = { 'train': [ 'top_mosaic_09cm_area1.tif', 'top_mosaic_09cm_area11.tif', @@ -102,22 +103,22 @@ class Vaihingen2D(NonGeoDataset): 'top_mosaic_09cm_area29.tif', ], } - classes = [ + classes = ( 'Clutter/background', 'Impervious surfaces', 'Building', 'Low Vegetation', 'Tree', 'Car', - ] - colormap = [ + ) + colormap = ( (255, 0, 0), (255, 255, 255), (0, 0, 255), (0, 255, 255), (0, 255, 0), (255, 255, 0), - ] + ) def __init__( self, @@ -258,7 +259,7 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap + sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap) ) if 'prediction' in sample: ncols += 1 @@ -266,7 +267,7 @@ def plot( sample['image'][:3], sample['prediction'], alpha=alpha, - colors=self.colormap, + colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index adce27494af..4840a788604 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import numpy as np @@ -158,18 +158,18 @@ class VHR10(NonGeoDataset): ``annotations.json`` file for the "positive" image set """ - image_meta = { + image_meta: ClassVar[dict[str, str]] = { 'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/NWPU%20VHR-10%20dataset.zip', 'filename': 'NWPU VHR-10 dataset.zip', 'md5': '6add6751469c12dd8c8d6223064c6c4d', } - target_meta = { + target_meta: ClassVar[dict[str, str]] = { 'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/annotations.json', 'filename': 'annotations.json', 'md5': '7c76ec50c17a61bb0514050d20f22c08', } - categories = [ + categories = ( 'background', 'airplane', 'ships', @@ -181,7 +181,7 @@ class VHR10(NonGeoDataset): 'harbor', 'bridge', 'vehicle', - ] + ) def __init__( self, diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index 136689894f1..fe51f6ade8f 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -6,7 +6,7 @@ import glob import json import os -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any import pandas as pd @@ -53,7 +53,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): label_name = 'percent(t)' - all_variable_names = [ + all_variable_names = ( # "date", 'slope(t)', 'elevation(t)', @@ -193,12 +193,12 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): 'vh_vv(t-3)', 'lat', 'lon', - ] + ) def __init__( self, root: Path = 'data', - input_features: list[str] = all_variable_names, + input_features: Iterable[str] = all_variable_names, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, ) -> None: @@ -273,7 +273,7 @@ def _load_data(self) -> pd.DataFrame: data_rows.append(data_dict) df = pd.DataFrame(data_rows) - df = df[self.input_features + [self.label_name]] + df = df[[*self.input_features, self.label_name]] return df def _verify(self) -> None: diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 0cb66cba9f3..a7f6a36456a 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -54,7 +55,7 @@ class XView2(NonGeoDataset): .. versionadded:: 0.2 """ - metadata = { + metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'filename': 'train_images_labels_targets.tar.gz', 'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16', @@ -66,8 +67,8 @@ class XView2(NonGeoDataset): 'directory': 'test', }, } - classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed'] - colormap = ['green', 'blue', 'orange', 'red'] + classes = ('background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed') + colormap = ('green', 'blue', 'orange', 'red') def __init__( self, @@ -242,10 +243,16 @@ def plot( """ ncols = 2 image1 = draw_semantic_segmentation_masks( - sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap + sample['image'][0], + sample['mask'][0], + alpha=alpha, + colors=list(self.colormap), ) image2 = draw_semantic_segmentation_masks( - sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap + sample['image'][1], + sample['mask'][1], + alpha=alpha, + colors=list(self.colormap), ) if 'prediction' in sample: # NOTE: this assumes predictions are made for post ncols += 1 @@ -253,7 +260,7 @@ def plot( sample['image'][1], sample['prediction'], alpha=alpha, - colors=self.colormap, + colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index 394721b5237..2928dc58a70 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -52,15 +52,15 @@ class ZueriCrop(NonGeoDataset): * `h5py `_ to load the dataset """ - urls = [ + urls = ( 'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download', - 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', # noqa: E501 - ] - md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b'] - filenames = ['ZueriCrop.hdf5', 'labels.csv'] + 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', + ) + md5s = ('1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b') + filenames = ('ZueriCrop.hdf5', 'labels.csv') band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12') - rgb_bands = ['B04', 'B03', 'B02'] + rgb_bands = ('B04', 'B03', 'B02') def __init__( self, diff --git a/torchgeo/main.py b/torchgeo/main.py index b403d4fa50c..48e84b6e8cf 100644 --- a/torchgeo/main.py +++ b/torchgeo/main.py @@ -8,7 +8,7 @@ from lightning.pytorch.cli import ArgsType, LightningCLI # Allows classes to be referenced using only the class name -import torchgeo.datamodules # noqa: F401 +import torchgeo.datamodules import torchgeo.trainers # noqa: F401 from torchgeo.datamodules import BaseDataModule from torchgeo.trainers import BaseTask diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index 9e214d9d04e..1caf32a47a4 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -8,7 +8,7 @@ * https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/ * https://pytorch.org/vision/stable/models.html * https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py -""" # noqa: E501 +""" from collections.abc import Callable from typing import Any diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py index 32f0be01a61..e82fbd574bd 100644 --- a/torchgeo/models/dofa.py +++ b/torchgeo/models/dofa.py @@ -384,7 +384,7 @@ class DOFABase16_Weights(WeightsEnum): # type: ignore[misc] """ DOFA_MAE = Weights( - url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', # noqa: E501 + url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', transforms=_dofa_transforms, meta={ 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', @@ -403,7 +403,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc] """ DOFA_MAE = Weights( - url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', # noqa: E501 + url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', transforms=_dofa_transforms, meta={ 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py index 82545afa339..079b1222867 100644 --- a/torchgeo/models/rcf.py +++ b/torchgeo/models/rcf.py @@ -140,7 +140,7 @@ def _normalize( a numpy array of size (N, C, H, W) containing the normalized patches .. versionadded:: 0.5 - """ # noqa: E501 + """ n_patches = patches.shape[0] orig_shape = patches.shape patches = patches.reshape(patches.shape[0], -1) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 05d95ceb8f1..c62c71dea6a 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -11,8 +11,8 @@ from timm.models import ResNet from torchvision.models._api import Weights, WeightsEnum -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # Normalization either by 10K or channel-wise with band statistics _zhu_xlab_transforms = K.AugmentationSequential( K.Resize(256), @@ -22,7 +22,7 @@ ) # Normalization only available for RGB dataset, defined here: -# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501 +# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py _min = torch.tensor([3, 2, 0]) _max = torch.tensor([88, 103, 129]) _mean = torch.tensor([0.485, 0.456, 0.406]) @@ -37,7 +37,7 @@ ) # Normalization only available for RGB dataset, defined here: -# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501 +# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 _mean = torch.tensor([0.485, 0.456, 0.406]) _std = torch.tensor([0.229, 0.224, 0.225]) _gassl_transforms = K.AugmentationSequential( @@ -47,7 +47,7 @@ data_keys=None, ) -# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501 +# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 _ssl4eo_l_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.CenterCrop((224, 224)), @@ -70,7 +70,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] """ LANDSAT_TM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -83,7 +83,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_TM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -96,7 +96,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -109,7 +109,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -122,7 +122,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -135,7 +135,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -148,7 +148,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -161,7 +161,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -174,7 +174,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -187,7 +187,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -200,7 +200,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_ALL_MOCO = Weights( - url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -213,7 +213,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_RGB_MOCO = Weights( - url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -226,7 +226,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_RGB_SECO = Weights( - url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', transforms=_seco_transforms, meta={ 'dataset': 'SeCo Dataset', @@ -249,7 +249,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] """ FMOW_RGB_GASSL = Weights( - url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', transforms=_gassl_transforms, meta={ 'dataset': 'fMoW Dataset', @@ -262,7 +262,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_TM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -275,7 +275,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_TM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -288,7 +288,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -301,7 +301,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -314,7 +314,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -327,7 +327,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -340,7 +340,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -353,7 +353,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -366,7 +366,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -379,7 +379,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -392,7 +392,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL1_ALL_MOCO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -405,7 +405,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_ALL_DINO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -418,7 +418,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_ALL_MOCO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -431,7 +431,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_RGB_MOCO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -444,7 +444,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_RGB_SECO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth', transforms=_seco_transforms, meta={ 'dataset': 'SeCo Dataset', diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index 21df8dedd91..f29c2ffabcb 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -12,20 +12,20 @@ from torchvision.models import SwinTransformer from torchvision.models._api import Weights, WeightsEnum -# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501 +# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). -# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501 -# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501 +# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. +# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. _satlas_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None ) # Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). -# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501 -# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). # noqa: E501 +# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. +# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). _std = torch.tensor( [255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0] -) # noqa: E501 +) _mean = torch.zeros_like(_std) _sentinel2_ms_satlas_transforms = K.AugmentationSequential( K.Normalize(mean=_mean, std=_std), @@ -33,7 +33,7 @@ data_keys=None, ) -# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501 +# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). _landsat_satlas_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)), K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))), @@ -56,7 +56,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] """ NAIP_RGB_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth', # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth', transforms=_satlas_transforms, meta={ 'dataset': 'Satlas', @@ -68,7 +68,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_RGB_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth', # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth', transforms=_satlas_transforms, meta={ 'dataset': 'Satlas', @@ -80,7 +80,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_MS_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth', # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth', transforms=_sentinel2_ms_satlas_transforms, meta={ 'dataset': 'Satlas', @@ -93,7 +93,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL1_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth', # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth', transforms=_satlas_transforms, meta={ 'dataset': 'Satlas', @@ -106,7 +106,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth', # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth', transforms=_landsat_satlas_transforms, meta={ 'dataset': 'Satlas', @@ -126,7 +126,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] 'B09', 'B10', 'B11', - ], # noqa: E501 + ], }, ) diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py index a81ac13d48a..1878883f484 100644 --- a/torchgeo/models/vit.py +++ b/torchgeo/models/vit.py @@ -11,8 +11,8 @@ from timm.models.vision_transformer import VisionTransformer from torchvision.models._api import Weights, WeightsEnum -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # Normalization either by 10K or channel-wise with band statistics _zhu_xlab_transforms = K.AugmentationSequential( K.Resize(256), @@ -21,7 +21,7 @@ data_keys=None, ) -# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501 +# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 _ssl4eo_l_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.CenterCrop((224, 224)), @@ -44,7 +44,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] """ LANDSAT_TM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -57,7 +57,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_TM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -70,7 +70,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -83,7 +83,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -96,7 +96,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -109,7 +109,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_ETM_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -122,7 +122,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -135,7 +135,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -148,7 +148,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -161,7 +161,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_OLI_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -174,7 +174,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_ALL_DINO = Weights( - url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth', # noqa: E501 + url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -187,7 +187,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] ) SENTINEL2_ALL_MOCO = Weights( - url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth', # noqa: E501 + url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 87484e75730..b2210eb2518 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -53,7 +53,7 @@ def __init__( else: keys.append(key) - self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] # noqa: E501 + self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Perform augmentations and update data dict.