diff --git a/docs/conf.py b/docs/conf.py
index 55a15db9118..ec721bfff9d 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -19,7 +19,7 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('..'))
-import torchgeo # noqa: E402
+import torchgeo
# -- Project information -----------------------------------------------------
diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb
index 54107f4f8d0..99a981f9d2b 100644
--- a/docs/tutorials/custom_raster_dataset.ipynb
+++ b/docs/tutorials/custom_raster_dataset.ipynb
@@ -369,8 +369,8 @@
" date_format = '%Y%m%dT%H%M%S'\n",
" is_image = True\n",
" separate_files = True\n",
- " all_bands = ['B02', 'B03', 'B04', 'B08']\n",
- " rgb_bands = ['B04', 'B03', 'B02']"
+ " all_bands = ('B02', 'B03', 'B04', 'B08')\n",
+ " rgb_bands = ('B04', 'B03', 'B02')"
]
},
{
@@ -432,8 +432,8 @@
" date_format = '%Y%m%dT%H%M%S'\n",
" is_image = True\n",
" separate_files = True\n",
- " all_bands = ['B02', 'B03', 'B04', 'B08']\n",
- " rgb_bands = ['B04', 'B03', 'B02']\n",
+ " all_bands = ('B02', 'B03', 'B04', 'B08')\n",
+ " rgb_bands = ('B04', 'B03', 'B02')\n",
"\n",
" def plot(self, sample):\n",
" # Find the correct band index order\n",
diff --git a/experiments/ssl4eo/download_ssl4eo.py b/experiments/ssl4eo/download_ssl4eo.py
index d93413bdd25..2708346482f 100755
--- a/experiments/ssl4eo/download_ssl4eo.py
+++ b/experiments/ssl4eo/download_ssl4eo.py
@@ -125,7 +125,7 @@ def filter_collection(
if filtered.size().getInfo() == 0:
raise ee.EEException(
- f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' # noqa: E501
+ f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.'
)
return filtered
diff --git a/experiments/ssl4eo/sample_ssl4eo.py b/experiments/ssl4eo/sample_ssl4eo.py
index 68d1056df55..69f82f10283 100755
--- a/experiments/ssl4eo/sample_ssl4eo.py
+++ b/experiments/ssl4eo/sample_ssl4eo.py
@@ -47,7 +47,7 @@
def get_world_cities(
download_root: str = 'world_cities', size: int = 10000
) -> pd.DataFrame:
- url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' # noqa: E501
+ url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip'
filename = 'worldcities.csv'
download_and_extract_archive(url, download_root)
cols = ['city', 'lat', 'lng', 'population']
diff --git a/pyproject.toml b/pyproject.toml
index 646e77ff41a..ccc94cd766c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -268,7 +268,7 @@ quote-style = "single"
skip-magic-trailing-comma = true
[tool.ruff.lint]
-extend-select = ["ANN", "D", "I", "NPY201", "UP"]
+extend-select = ["ANN", "D", "I", "NPY201", "RUF", "UP"]
ignore = ["ANN101", "ANN102", "ANN401"]
[tool.ruff.lint.per-file-ignores]
diff --git a/tests/data/dfc2022/data.py b/tests/data/dfc2022/data.py
index 39d41f5d945..67323262452 100755
--- a/tests/data/dfc2022/data.py
+++ b/tests/data/dfc2022/data.py
@@ -19,36 +19,36 @@
train_set = [
{
- 'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', # noqa: E501
- 'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
- 'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', # noqa: E501
+ 'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif',
+ 'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif',
+ 'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif',
},
{
- 'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', # noqa: E501
- 'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
- 'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', # noqa: E501
+ 'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif',
+ 'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif',
+ 'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif',
},
]
unlabeled_set = [
{
- 'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', # noqa: E501
- 'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
+ 'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif',
+ 'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif',
},
{
- 'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', # noqa: E501
- 'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
+ 'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif',
+ 'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif',
},
]
val_set = [
{
- 'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', # noqa: E501
- 'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
+ 'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif',
+ 'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif',
},
{
- 'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', # noqa: E501
- 'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
+ 'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif',
+ 'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif',
},
]
diff --git a/tests/data/seasonet/data.py b/tests/data/seasonet/data.py
index 68fa8ffe397..e3197ddde12 100644
--- a/tests/data/seasonet/data.py
+++ b/tests/data/seasonet/data.py
@@ -112,7 +112,7 @@
# Compute checksums
with open(archive, 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
- print(f'{season}: {repr(md5)}')
+ print(f'{season}: {md5!r}')
# Write meta.csv
with open('meta.csv', 'w') as f:
@@ -121,7 +121,7 @@
# Compute checksums
with open('meta.csv', 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
- print(f'meta.csv: {repr(md5)}')
+ print(f'meta.csv: {md5!r}')
os.makedirs('splits', exist_ok=True)
@@ -138,4 +138,4 @@
# Compute checksums
with open('splits.zip', 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
- print(f'splits: {repr(md5)}')
+ print(f'splits: {md5!r}')
diff --git a/tests/datasets/test_eurocrops.py b/tests/datasets/test_eurocrops.py
index 5354d325fed..3b2d4fc63f7 100644
--- a/tests/datasets/test_eurocrops.py
+++ b/tests/datasets/test_eurocrops.py
@@ -83,5 +83,5 @@ def test_invalid_query(self, dataset: EuroCrops) -> None:
dataset[query]
def test_integrity_error(self, dataset: EuroCrops) -> None:
- dataset.zenodo_files = [('AA.zip', 'invalid')]
+ dataset.zenodo_files = (('AA.zip', 'invalid'),)
assert not dataset._check_integrity()
diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py
index 28c50303b5e..8ff30531fee 100644
--- a/tests/datasets/test_geo.py
+++ b/tests/datasets/test_geo.py
@@ -72,7 +72,7 @@ class CustomVectorDataset(VectorDataset):
class CustomSentinelDataset(Sentinel2):
- all_bands: list[str] = []
+ all_bands: tuple[str, ...] = ()
separate_files = False
@@ -356,7 +356,7 @@ def test_no_data(self, tmp_path: Path) -> None:
def test_no_all_bands(self) -> None:
root = os.path.join('tests', 'data', 'sentinel2')
- bands = ['B04', 'B03', 'B02']
+ bands = ('B04', 'B03', 'B02')
transforms = nn.Identity()
cache = True
msg = (
diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py
index a0b2b61add7..b0c13b6075b 100644
--- a/tests/trainers/test_byol.py
+++ b/tests/trainers/test_byol.py
@@ -73,7 +73,7 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
@pytest.fixture
def weights(self) -> WeightsEnum:
diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py
index 40c4ee3d0d2..be8132c808b 100644
--- a/tests/trainers/test_classification.py
+++ b/tests/trainers/test_classification.py
@@ -103,13 +103,13 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
try:
- main(['test'] + args)
+ main(['test', *args])
except MisconfigurationException:
pass
try:
- main(['predict'] + args)
+ main(['predict', *args])
except MisconfigurationException:
pass
@@ -259,13 +259,13 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
try:
- main(['test'] + args)
+ main(['test', *args])
except MisconfigurationException:
pass
try:
- main(['predict'] + args)
+ main(['predict', *args])
except MisconfigurationException:
pass
diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py
index 035bdacc260..742cda3c371 100644
--- a/tests/trainers/test_detection.py
+++ b/tests/trainers/test_detection.py
@@ -97,13 +97,13 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
try:
- main(['test'] + args)
+ main(['test', *args])
except MisconfigurationException:
pass
try:
- main(['predict'] + args)
+ main(['predict', *args])
except MisconfigurationException:
pass
diff --git a/tests/trainers/test_iobench.py b/tests/trainers/test_iobench.py
index f67d19582ac..0fbde73bdc8 100644
--- a/tests/trainers/test_iobench.py
+++ b/tests/trainers/test_iobench.py
@@ -27,12 +27,12 @@ def test_trainer(self, name: str, fast_dev_run: bool) -> None:
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
try:
- main(['test'] + args)
+ main(['test', *args])
except MisconfigurationException:
pass
try:
- main(['predict'] + args)
+ main(['predict', *args])
except MisconfigurationException:
pass
diff --git a/tests/trainers/test_moco.py b/tests/trainers/test_moco.py
index 24b9fdb49cd..32c002dc573 100644
--- a/tests/trainers/test_moco.py
+++ b/tests/trainers/test_moco.py
@@ -63,7 +63,7 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'):
diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py
index 349fa6ab7d1..00c9da65321 100644
--- a/tests/trainers/test_regression.py
+++ b/tests/trainers/test_regression.py
@@ -84,13 +84,13 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
try:
- main(['test'] + args)
+ main(['test', *args])
except MisconfigurationException:
pass
try:
- main(['predict'] + args)
+ main(['predict', *args])
except MisconfigurationException:
pass
@@ -237,13 +237,13 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
try:
- main(['test'] + args)
+ main(['test', *args])
except MisconfigurationException:
pass
try:
- main(['predict'] + args)
+ main(['predict', *args])
except MisconfigurationException:
pass
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index 5f096da75e7..ea4b0646521 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -108,13 +108,13 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
try:
- main(['test'] + args)
+ main(['test', *args])
except MisconfigurationException:
pass
try:
- main(['predict'] + args)
+ main(['predict', *args])
except MisconfigurationException:
pass
diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py
index eacf2c0ed70..7e1292ab7c0 100644
--- a/tests/trainers/test_simclr.py
+++ b/tests/trainers/test_simclr.py
@@ -63,7 +63,7 @@ def test_trainer(
'1',
]
- main(['fit'] + args)
+ main(['fit', *args])
def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'):
diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py
index f1ed2346164..1160f037366 100644
--- a/torchgeo/datamodules/seco.py
+++ b/torchgeo/datamodules/seco.py
@@ -37,7 +37,7 @@ def __init__(
seasons = kwargs.get('seasons', 1)
# Normalization only available for RGB dataset, defined here:
- # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
+ # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
if bands == SeasonalContrastS2.rgb_bands:
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])
diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py
index 64701cef519..99126698ffa 100644
--- a/torchgeo/datamodules/so2sat.py
+++ b/torchgeo/datamodules/so2sat.py
@@ -3,7 +3,7 @@
"""So2Sat datamodule."""
-from typing import Any
+from typing import Any, ClassVar
import torch
from torch import Generator, Tensor
@@ -21,7 +21,7 @@ class So2SatDataModule(NonGeoDataModule):
"train" set and use the "test" set as the test set.
"""
- means_per_version: dict[str, Tensor] = {
+ means_per_version: ClassVar[dict[str, Tensor]] = {
'2': torch.tensor(
[
-0.00003591224260,
@@ -91,7 +91,7 @@ class So2SatDataModule(NonGeoDataModule):
}
means_per_version['3_culture_10'] = means_per_version['2']
- stds_per_version: dict[str, Tensor] = {
+ stds_per_version: ClassVar[dict[str, Tensor]] = {
'2': torch.tensor(
[
0.17555201,
diff --git a/torchgeo/datamodules/ssl4eo.py b/torchgeo/datamodules/ssl4eo.py
index 6ad558dcf87..f0b1ecdee46 100644
--- a/torchgeo/datamodules/ssl4eo.py
+++ b/torchgeo/datamodules/ssl4eo.py
@@ -45,7 +45,7 @@ class SSL4EOS12DataModule(NonGeoDataModule):
.. versionadded:: 0.5
"""
- # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
+ # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
mean = torch.tensor(0)
std = torch.tensor(10000)
diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py
index 6dcd690c735..b46695957be 100644
--- a/torchgeo/datasets/advance.py
+++ b/torchgeo/datasets/advance.py
@@ -63,14 +63,14 @@ class ADVANCE(NonGeoDataset):
* `scipy `_ to load the audio files to tensors
"""
- urls = [
+ urls = (
'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1',
'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1',
- ]
- filenames = ['ADVANCE_vision.zip', 'ADVANCE_sound.zip']
- md5s = ['a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31']
- directories = ['vision', 'sound']
- classes = [
+ )
+ filenames = ('ADVANCE_vision.zip', 'ADVANCE_sound.zip')
+ md5s = ('a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31')
+ directories = ('vision', 'sound')
+ classes: tuple[str, ...] = (
'airport',
'beach',
'bridge',
@@ -84,7 +84,7 @@ class ADVANCE(NonGeoDataset):
'sparse shrub land',
'sports land',
'train station',
- ]
+ )
def __init__(
self,
@@ -119,7 +119,7 @@ def __init__(
raise DatasetNotFoundError(self)
self.files = self._load_files(self.root)
- self.classes = sorted({f['cls'] for f in self.files})
+ self.classes = tuple(sorted({f['cls'] for f in self.files}))
self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)}
def __getitem__(self, index: int) -> dict[str, Tensor]:
diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py
index 1b80e11555b..1ceaa9c9d3d 100644
--- a/torchgeo/datasets/agb_live_woody_density.py
+++ b/torchgeo/datasets/agb_live_woody_density.py
@@ -46,7 +46,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
is_image = False
- url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' # noqa: E501
+ url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326'
base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson'
diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py
index c1eae6de222..e40fca5eecf 100644
--- a/torchgeo/datasets/agrifieldnet.py
+++ b/torchgeo/datasets/agrifieldnet.py
@@ -7,7 +7,7 @@
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import torch
@@ -90,8 +90,8 @@ class AgriFieldNet(RasterDataset):
_(?PB[0-9A-Z]{2})_10m
"""
- rgb_bands = ['B04', 'B03', 'B02']
- all_bands = [
+ rgb_bands = ('B04', 'B03', 'B02')
+ all_bands = (
'B01',
'B02',
'B03',
@@ -104,9 +104,9 @@ class AgriFieldNet(RasterDataset):
'B09',
'B11',
'B12',
- ]
+ )
- cmap = {
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255),
1: (255, 211, 0, 255),
2: (255, 37, 37, 255),
diff --git a/torchgeo/datasets/airphen.py b/torchgeo/datasets/airphen.py
index 3b0caf607ed..12b8c38141c 100644
--- a/torchgeo/datasets/airphen.py
+++ b/torchgeo/datasets/airphen.py
@@ -40,8 +40,8 @@ class Airphen(RasterDataset):
# Each camera measures a custom set of spectral bands chosen at purchase time.
# Hiphen offers 8 bands to choose from, sorted from short to long wavelength.
- all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8']
- rgb_bands = ['B4', 'B3', 'B1']
+ all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8')
+ rgb_bands = ('B4', 'B3', 'B1')
def plot(
self,
diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py
index 70bd76373d5..4dd1ae927de 100644
--- a/torchgeo/datasets/benin_cashews.py
+++ b/torchgeo/datasets/benin_cashews.py
@@ -147,7 +147,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
)
rgb_bands = ('B04', 'B03', 'B02')
- classes = [
+ classes = (
'No data',
'Well-managed planatation',
'Poorly-managed planatation',
@@ -155,7 +155,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
'Residential',
'Background',
'Uncertain',
- ]
+ )
# Same for all tiles
tile_height = 1186
@@ -199,11 +199,13 @@ def __init__(
# Calculate the indices that we will use over all tiles
self.chips_metadata = []
- for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
- self.tile_height - self.chip_size
+ for y in [
+ *list(range(0, self.tile_height - self.chip_size, stride)),
+ self.tile_height - self.chip_size,
]:
- for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
- self.tile_width - self.chip_size
+ for x in [
+ *list(range(0, self.tile_width - self.chip_size, stride)),
+ self.tile_width - self.chip_size,
]:
self.chips_metadata.append((y, x))
diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py
index 0f4c94565e1..38669cd6ff1 100644
--- a/torchgeo/datasets/bigearthnet.py
+++ b/torchgeo/datasets/bigearthnet.py
@@ -7,6 +7,7 @@
import json
import os
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -124,9 +125,9 @@ class BigEarthNet(NonGeoDataset):
* https://doi.org/10.1109/IGARSS.2019.8900532
- """ # noqa: E501
+ """
- class_sets = {
+ class_sets: ClassVar[dict[int, list[str]]] = {
19: [
'Urban fabric',
'Industrial or commercial units',
@@ -197,7 +198,7 @@ class BigEarthNet(NonGeoDataset):
],
}
- label_converter = {
+ label_converter: ClassVar[dict[int, int]] = {
0: 0,
1: 0,
2: 1,
@@ -232,24 +233,24 @@ class BigEarthNet(NonGeoDataset):
42: 18,
}
- splits_metadata = {
+ splits_metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
- 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false', # noqa: E501
+ 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false',
'filename': 'bigearthnet-train.csv',
'md5': '623e501b38ab7b12fe44f0083c00986d',
},
'val': {
- 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false', # noqa: E501
+ 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false',
'filename': 'bigearthnet-val.csv',
'md5': '22efe8ed9cbd71fa10742ff7df2b7978',
},
'test': {
- 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false', # noqa: E501
+ 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false',
'filename': 'bigearthnet-test.csv',
'md5': '697fb90677e30571b9ac7699b7e5b432',
},
}
- metadata = {
+ metadata: ClassVar[dict[str, dict[str, str]]] = {
's1': {
'url': 'https://zenodo.org/records/12687186/files/BigEarthNet-S1-v1.0.tar.gz',
'md5': '94ced73440dea8c7b9645ee738c5a172',
diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py
index dc757b96c4b..5d3fe4b631f 100644
--- a/torchgeo/datasets/biomassters.py
+++ b/torchgeo/datasets/biomassters.py
@@ -50,7 +50,7 @@ class BioMassters(NonGeoDataset):
.. versionadded:: 0.5
"""
- valid_splits = ['train', 'test']
+ valid_splits = ('train', 'test')
valid_sensors = ('S1', 'S2')
metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv'
diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py
index f91993763dd..ec6afdb34b7 100644
--- a/torchgeo/datasets/cbf.py
+++ b/torchgeo/datasets/cbf.py
@@ -30,7 +30,7 @@ class CanadianBuildingFootprints(VectorDataset):
# https://github.com/microsoft/CanadianBuildingFootprints/issues/11
url = 'https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/'
- provinces_territories = [
+ provinces_territories = (
'Alberta',
'BritishColumbia',
'Manitoba',
@@ -44,8 +44,8 @@ class CanadianBuildingFootprints(VectorDataset):
'Quebec',
'Saskatchewan',
'YukonTerritory',
- ]
- md5s = [
+ )
+ md5s = (
'8b4190424e57bb0902bd8ecb95a9235b',
'fea05d6eb0006710729c675de63db839',
'adf11187362624d68f9c69aaa693c46f',
@@ -59,7 +59,7 @@ class CanadianBuildingFootprints(VectorDataset):
'9ff4417ae00354d39a0cf193c8df592c',
'a51078d8e60082c7d3a3818240da6dd5',
'c11f3bd914ecabd7cac2cb2871ec0261',
- ]
+ )
def __init__(
self,
diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py
index 25e42d2b030..07bd8193b79 100644
--- a/torchgeo/datasets/cdl.py
+++ b/torchgeo/datasets/cdl.py
@@ -6,7 +6,7 @@
import os
import pathlib
from collections.abc import Callable, Iterable
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
import torch
@@ -38,7 +38,7 @@ class CDL(RasterDataset):
If you use this dataset in your research, please cite it using the following format:
* https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0
- """ # noqa: E501
+ """
filename_glob = '*_30m_cdls.tif'
filename_regex = r"""
@@ -49,8 +49,8 @@ class CDL(RasterDataset):
date_format = '%Y'
is_image = False
- url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip' # noqa: E501
- md5s = {
+ url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip'
+ md5s: ClassVar[dict[int, str]] = {
2023: '8c7685d6278d50c554f934b16a6076b7',
2022: '754cf50670cdfee511937554785de3e6',
2021: '27606eab08fe975aa138baad3e5dfcd8',
@@ -69,7 +69,7 @@ class CDL(RasterDataset):
2008: '0610f2f17ab60a9fbb3baeb7543993a4',
}
- cmap = {
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255),
1: (255, 211, 0, 255),
2: (255, 37, 37, 255),
diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py
index 61eefbdf35c..ba773607a54 100644
--- a/torchgeo/datasets/chabud.py
+++ b/torchgeo/datasets/chabud.py
@@ -4,7 +4,8 @@
"""ChaBuD dataset."""
import os
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -53,7 +54,7 @@ class ChaBuD(NonGeoDataset):
.. versionadded:: 0.6
"""
- all_bands = [
+ all_bands = (
'B01',
'B02',
'B03',
@@ -66,10 +67,10 @@ class ChaBuD(NonGeoDataset):
'B09',
'B11',
'B12',
- ]
- rgb_bands = ['B04', 'B03', 'B02']
- folds = {'train': [1, 2, 3, 4], 'val': [0]}
- url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5' # noqa: E501
+ )
+ rgb_bands = ('B04', 'B03', 'B02')
+ folds: ClassVar[dict[str, list[int]]] = {'train': [1, 2, 3, 4], 'val': [0]}
+ url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5'
filename = 'train_eval.hdf5'
md5 = '15d78fb825f9a81dad600db828d22c08'
@@ -77,7 +78,7 @@ def __init__(
self,
root: Path = 'data',
split: str = 'train',
- bands: list[str] = all_bands,
+ bands: Sequence[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py
index c3dd0ae88d0..46dbaefc806 100644
--- a/torchgeo/datasets/chesapeake.py
+++ b/torchgeo/datasets/chesapeake.py
@@ -9,7 +9,7 @@
import sys
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import fiona
import matplotlib.pyplot as plt
@@ -39,7 +39,7 @@ class Chesapeake(RasterDataset, ABC):
The Chesapeake Bay Land Use and Land Cover Database (LULC) facilitates
characterization of the landscape and land change for and between discrete time
- periods. The database was developed by the University of Vermont’s Spatial Analysis
+ periods. The database was developed by the University of Vermont's Spatial Analysis
Laboratory in cooperation with Chesapeake Conservancy (CC) and U.S. Geological
Survey (USGS) as part of a 6-year Cooperative Agreement between Chesapeake
Conservancy and the U.S. Environmental Protection Agency (EPA) and a separate
@@ -83,7 +83,7 @@ def state(self) -> str:
"""State abbreviation."""
return self.__class__.__name__[-2:].lower()
- cmap = {
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
11: (0, 92, 230, 255),
12: (0, 92, 230, 255),
13: (0, 92, 230, 255),
@@ -255,7 +255,7 @@ def plot(
class ChesapeakeDC(Chesapeake):
"""This subset of the dataset contains data only for Washington, D.C."""
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2013: '9f1df21afbb9d5c0fcf33af7f6750a7f',
2017: 'c45e4af2950e1c93ecd47b61af296d9b',
}
@@ -264,7 +264,7 @@ class ChesapeakeDC(Chesapeake):
class ChesapeakeDE(Chesapeake):
"""This subset of the dataset contains data only for Delaware."""
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2013: '5850d96d897babba85610658aeb5951a',
2018: 'ee94c8efeae423d898677104117bdebc',
}
@@ -273,7 +273,7 @@ class ChesapeakeDE(Chesapeake):
class ChesapeakeMD(Chesapeake):
"""This subset of the dataset contains data only for Maryland."""
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2013: '9c3ca5040668d15284c1bd64b7d6c7a0',
2018: '0647530edf8bec6e60f82760dcc7db9c',
}
@@ -282,7 +282,7 @@ class ChesapeakeMD(Chesapeake):
class ChesapeakeNY(Chesapeake):
"""This subset of the dataset contains data only for New York."""
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2013: '38a29b721610ba661a7f8b6ec71a48b7',
2017: '4c1b1a50fd9368cd7b8b12c4d80c63f3',
}
@@ -291,7 +291,7 @@ class ChesapeakeNY(Chesapeake):
class ChesapeakePA(Chesapeake):
"""This subset of the dataset contains data only for Pennsylvania."""
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2013: '86febd603a120a49ef7d23ef486152a3',
2017: 'b11d92e4471e8cb887c790d488a338c1',
}
@@ -300,7 +300,7 @@ class ChesapeakePA(Chesapeake):
class ChesapeakeVA(Chesapeake):
"""This subset of the dataset contains data only for Virginia."""
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2014: '49c9700c71854eebd00de24d8488eb7c',
2018: '51731c8b5632978bfd1df869ea10db5b',
}
@@ -309,7 +309,7 @@ class ChesapeakeVA(Chesapeake):
class ChesapeakeWV(Chesapeake):
"""This subset of the dataset contains data only for West Virginia."""
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2014: '32fea42fae147bd58a83e3ea6cccfb94',
2018: '80f25dcba72e39685ab33215c5d97292',
}
@@ -337,16 +337,16 @@ class ChesapeakeCVPR(GeoDataset):
* https://doi.org/10.1109/cvpr.2019.01301
"""
- subdatasets = ['base', 'prior_extension']
- urls = {
- 'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip', # noqa: E501
- 'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1', # noqa: E501
+ subdatasets = ('base', 'prior_extension')
+ urls: ClassVar[dict[str, str]] = {
+ 'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip',
+ 'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1',
}
- filenames = {
+ filenames: ClassVar[dict[str, str]] = {
'base': 'cvpr_chesapeake_landcover.zip',
'prior_extension': 'cvpr_chesapeake_landcover_prior_extension.zip',
}
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'base': '1225ccbb9590e9396875f221e5031514',
'prior_extension': '402f41d07823c8faf7ea6960d7c4e17a',
}
@@ -354,7 +354,7 @@ class ChesapeakeCVPR(GeoDataset):
crs = CRS.from_epsg(3857)
res = 1
- lc_cmap = {
+ lc_cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0),
1: (0, 197, 255, 255),
2: (38, 115, 0, 255),
@@ -374,7 +374,7 @@ class ChesapeakeCVPR(GeoDataset):
]
)
- valid_layers = [
+ valid_layers = (
'naip-new',
'naip-old',
'landsat-leaf-on',
@@ -383,8 +383,8 @@ class ChesapeakeCVPR(GeoDataset):
'lc',
'buildings',
'prior_from_cooccurrences_101_31_no_osm_no_buildings',
- ]
- states = ['de', 'md', 'va', 'wv', 'pa', 'ny']
+ )
+ states = ('de', 'md', 'va', 'wv', 'pa', 'ny')
splits = (
[f'{state}-train' for state in states]
+ [f'{state}-val' for state in states]
@@ -392,7 +392,7 @@ class ChesapeakeCVPR(GeoDataset):
)
# these are used to check the integrity of the dataset
- _files = [
+ _files = (
'de_1m_2013_extended-debuffered-test_tiles',
'de_1m_2013_extended-debuffered-train_tiles',
'de_1m_2013_extended-debuffered-val_tiles',
@@ -412,18 +412,18 @@ class ChesapeakeCVPR(GeoDataset):
'wv_1m_2014_extended-debuffered-train_tiles',
'wv_1m_2014_extended-debuffered-val_tiles',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif',
- 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif', # noqa: E501
- 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif', # noqa: E501
+ 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif',
+ 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif',
- 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501
+ 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',
'spatial_index.geojson',
- ]
+ )
p_src_crs = pyproj.CRS('epsg:3857')
- p_transformers = {
+ p_transformers: ClassVar[dict[str, CRS]] = {
'epsg:26917': pyproj.Transformer.from_crs(
p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
).transform,
@@ -511,7 +511,7 @@ def __init__(
'lc': row['properties']['lc'],
'nlcd': row['properties']['nlcd'],
'buildings': row['properties']['buildings'],
- 'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn, # noqa: E501
+ 'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn,
},
)
diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py
index 7c7ed8b630c..e0ca0045e33 100644
--- a/torchgeo/datasets/cloud_cover.py
+++ b/torchgeo/datasets/cloud_cover.py
@@ -5,6 +5,7 @@
import os
from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -55,9 +56,9 @@ class CloudCoverDetection(NonGeoDataset):
"""
url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_cloud_cover_detection_challenge_v1/final'
- all_bands = ['B02', 'B03', 'B04', 'B08']
- rgb_bands = ['B04', 'B03', 'B02']
- splits = {'train': 'public', 'test': 'private'}
+ all_bands = ('B02', 'B03', 'B04', 'B08')
+ rgb_bands = ('B04', 'B03', 'B02')
+ splits: ClassVar[dict[str, str]] = {'train': 'public', 'test': 'private'}
def __init__(
self,
diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py
index 91ddbf8c54a..61d5c4acafd 100644
--- a/torchgeo/datasets/cms_mangrove_canopy.py
+++ b/torchgeo/datasets/cms_mangrove_canopy.py
@@ -42,7 +42,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip'
md5 = '3e7f9f23bf971c25e828b36e6c5496e3'
- all_countries = [
+ all_countries = (
'AndamanAndNicobar',
'Angola',
'Anguilla',
@@ -164,9 +164,9 @@ class CMSGlobalMangroveCanopy(RasterDataset):
'VirginIslandsUs',
'WallisAndFutuna',
'Yemen',
- ]
+ )
- measurements = ['agb', 'hba95', 'hmax95']
+ measurements = ('agb', 'hba95', 'hmax95')
def __init__(
self,
diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py
index cae82e597fc..fa97fa87037 100644
--- a/torchgeo/datasets/cowc.py
+++ b/torchgeo/datasets/cowc.py
@@ -50,12 +50,12 @@ def base_url(self) -> str:
@property
@abc.abstractmethod
- def filenames(self) -> list[str]:
+ def filenames(self) -> tuple[str, ...]:
"""List of files to download."""
@property
@abc.abstractmethod
- def md5s(self) -> list[str]:
+ def md5s(self) -> tuple[str, ...]:
"""List of MD5 checksums of files to download."""
@property
@@ -239,7 +239,7 @@ class COWCCounting(COWC):
base_url = (
'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/'
)
- filenames = [
+ filenames = (
'COWC_train_list_64_class.txt.bz2',
'COWC_test_list_64_class.txt.bz2',
'COWC_Counting_Toronto_ISPRS.tbz',
@@ -248,8 +248,8 @@ class COWCCounting(COWC):
'COWC_Counting_Vaihingen_ISPRS.tbz',
'COWC_Counting_Columbus_CSUAV_AFRL.tbz',
'COWC_Counting_Utah_AGRC.tbz',
- ]
- md5s = [
+ )
+ md5s = (
'187543d20fa6d591b8da51136e8ef8fb',
'930cfd6e160a7b36db03146282178807',
'bc2613196dfa93e66d324ae43e7c1fdb',
@@ -258,7 +258,7 @@ class COWCCounting(COWC):
'4009c1e420566390746f5b4db02afdb9',
'daf8033c4e8ceebbf2c3cac3fabb8b10',
'777ec107ed2a3d54597a739ce74f95ad',
- ]
+ )
filename = 'COWC_{}_list_64_class.txt'
@@ -268,7 +268,7 @@ class COWCDetection(COWC):
base_url = (
'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/'
)
- filenames = [
+ filenames = (
'COWC_train_list_detection.txt.bz2',
'COWC_test_list_detection.txt.bz2',
'COWC_Detection_Toronto_ISPRS.tbz',
@@ -277,8 +277,8 @@ class COWCDetection(COWC):
'COWC_Detection_Vaihingen_ISPRS.tbz',
'COWC_Detection_Columbus_CSUAV_AFRL.tbz',
'COWC_Detection_Utah_AGRC.tbz',
- ]
- md5s = [
+ )
+ md5s = (
'c954a5a3dac08c220b10cfbeec83893c',
'c6c2d0a78f12a2ad88b286b724a57c1a',
'11af24f43b198b0f13c8e94814008a48',
@@ -287,7 +287,7 @@ class COWCDetection(COWC):
'23945d5b22455450a938382ccc2a8b27',
'f40522dc97bea41b10117d4a5b946a6f',
'195da7c9443a939a468c9f232fd86ee3',
- ]
+ )
filename = 'COWC_{}_list_detection.txt'
diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py
index 30f3a43f634..bb3e4b3f3c5 100644
--- a/torchgeo/datasets/cropharvest.py
+++ b/torchgeo/datasets/cropharvest.py
@@ -7,6 +7,7 @@
import json
import os
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -55,7 +56,7 @@ class CropHarvest(NonGeoDataset):
"""
# https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py
- all_bands = [
+ all_bands = (
'VV',
'VH',
'B2',
@@ -74,12 +75,12 @@ class CropHarvest(NonGeoDataset):
'elevation',
'slope',
'NDVI',
- ]
- rgb_bands = ['B4', 'B3', 'B2']
+ )
+ rgb_bands = ('B4', 'B3', 'B2')
features_url = 'https://zenodo.org/records/7257688/files/features.tar.gz?download=1'
labels_url = 'https://zenodo.org/records/7257688/files/labels.geojson?download=1'
- file_dict = {
+ file_dict: ClassVar[dict[str, dict[str, str]]] = {
'features': {
'url': features_url,
'filename': 'features.tar.gz',
diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py
index 4e262d5266f..2248dab4292 100644
--- a/torchgeo/datasets/cv4a_kenya_crop_type.py
+++ b/torchgeo/datasets/cv4a_kenya_crop_type.py
@@ -65,8 +65,8 @@ class CV4AKenyaCropType(NonGeoDataset):
"""
url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge'
- tiles = list(map(str, range(4)))
- dates = [
+ tiles = tuple(map(str, range(4)))
+ dates = (
'20190606',
'20190701',
'20190706',
@@ -80,7 +80,7 @@ class CV4AKenyaCropType(NonGeoDataset):
'20190924',
'20191004',
'20191103',
- ]
+ )
all_bands = (
'B01',
'B02',
@@ -96,7 +96,7 @@ class CV4AKenyaCropType(NonGeoDataset):
'B12',
'CLD',
)
- rgb_bands = ['B04', 'B03', 'B02']
+ rgb_bands = ('B04', 'B03', 'B02')
# Same for all tiles
tile_height = 3035
@@ -141,11 +141,13 @@ def __init__(
# Calculate the indices that we will use over all tiles
self.chips_metadata = []
for tile_index in range(len(self.tiles)):
- for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
- self.tile_height - self.chip_size
+ for y in [
+ *list(range(0, self.tile_height - self.chip_size, stride)),
+ self.tile_height - self.chip_size,
]:
- for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
- self.tile_width - self.chip_size
+ for x in [
+ *list(range(0, self.tile_width - self.chip_size, stride)),
+ self.tile_width - self.chip_size,
]:
self.chips_metadata.append((tile_index, y, x))
diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py
index fcd9fb7bac2..51b82f9ff92 100644
--- a/torchgeo/datasets/deepglobelandcover.py
+++ b/torchgeo/datasets/deepglobelandcover.py
@@ -74,13 +74,13 @@ class DeepGlobeLandCover(NonGeoDataset):
$ unzip deepglobe2018-landcover-segmentation-traindataset.zip
.. versionadded:: 0.3
- """ # noqa: E501
+ """
filename = 'data.zip'
data_root = 'data'
md5 = 'f32684b0b2bf6f8d604cd359a399c061'
- splits = ['train', 'test']
- classes = [
+ splits = ('train', 'test')
+ classes = (
'Urban land',
'Agriculture land',
'Rangeland',
@@ -88,8 +88,8 @@ class DeepGlobeLandCover(NonGeoDataset):
'Water',
'Barren land',
'Unknown',
- ]
- colormap = [
+ )
+ colormap = (
(0, 255, 255),
(255, 255, 0),
(255, 0, 255),
@@ -97,7 +97,7 @@ class DeepGlobeLandCover(NonGeoDataset):
(0, 0, 255),
(255, 255, 255),
(0, 0, 0),
- ]
+ )
def __init__(
self,
@@ -246,12 +246,15 @@ def plot(
"""
ncols = 1
image1 = draw_semantic_segmentation_masks(
- sample['image'], sample['mask'], alpha=alpha, colors=self.colormap
+ sample['image'], sample['mask'], alpha=alpha, colors=list(self.colormap)
)
if 'prediction' in sample:
ncols += 1
image2 = draw_semantic_segmentation_masks(
- sample['image'], sample['prediction'], alpha=alpha, colors=self.colormap
+ sample['image'],
+ sample['prediction'],
+ alpha=alpha,
+ colors=list(self.colormap),
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py
index 46d96e87d3f..edc79d8ae2b 100644
--- a/torchgeo/datasets/dfc2022.py
+++ b/torchgeo/datasets/dfc2022.py
@@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -75,9 +76,9 @@ class DFC2022(NonGeoDataset):
* https://doi.org/10.1007/s10994-020-05943-y
.. versionadded:: 0.3
- """ # noqa: E501
+ """
- classes = [
+ classes = (
'No information',
'Urban fabric',
'Industrial, commercial, public, military, private and transport units',
@@ -94,8 +95,8 @@ class DFC2022(NonGeoDataset):
'Wetlands',
'Water',
'Clouds and Shadows',
- ]
- colormap = [
+ )
+ colormap = (
'#231F20',
'#DB5F57',
'#DB9757',
@@ -112,8 +113,8 @@ class DFC2022(NonGeoDataset):
'#579BDB',
'#0062FF',
'#231F20',
- ]
- metadata = {
+ )
+ metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'filename': 'labeled_train.zip',
'md5': '2e87d6a218e466dd0566797d7298c7a9',
diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py
index 1d71a8940e6..5e6d4b80d99 100644
--- a/torchgeo/datasets/enviroatlas.py
+++ b/torchgeo/datasets/enviroatlas.py
@@ -6,7 +6,7 @@
import os
import sys
from collections.abc import Callable, Sequence
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import fiona
import matplotlib.pyplot as plt
@@ -54,9 +54,9 @@ class EnviroAtlas(GeoDataset):
crs = CRS.from_epsg(3857)
res = 1
- valid_prior_layers = ['prior', 'prior_no_osm_no_buildings']
+ valid_prior_layers = ('prior', 'prior_no_osm_no_buildings')
- valid_layers = [
+ valid_layers = (
'naip',
'nlcd',
'roads',
@@ -65,14 +65,15 @@ class EnviroAtlas(GeoDataset):
'waterbodies',
'buildings',
'lc',
- ] + valid_prior_layers
+ *valid_prior_layers,
+ )
- cities = [
+ cities = (
'pittsburgh_pa-2010_1m',
'durham_nc-2012_1m',
'austin_tx-2012_1m',
'phoenix_az-2010_1m',
- ]
+ )
splits = (
[f'{state}-train' for state in cities[:1]]
+ [f'{state}-val' for state in cities[:1]]
@@ -81,7 +82,7 @@ class EnviroAtlas(GeoDataset):
)
# these are used to check the integrity of the dataset
- _files = [
+ _files = (
'austin_tx-2012_1m-test_tiles-debuffered',
'austin_tx-2012_1m-val5_tiles-debuffered',
'durham_nc-2012_1m-test_tiles-debuffered',
@@ -100,13 +101,13 @@ class EnviroAtlas(GeoDataset):
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif',
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif',
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif',
- 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif', # noqa: E501
- 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501
+ 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif',
+ 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',
'spatial_index.geojson',
- ]
+ )
p_src_crs = pyproj.CRS('epsg:3857')
- p_transformers = {
+ p_transformers: ClassVar[dict[str, CRS]] = {
'epsg:26917': pyproj.Transformer.from_crs(
p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
).transform,
@@ -222,7 +223,7 @@ class EnviroAtlas(GeoDataset):
dtype=np.uint8,
)
- highres_classes = [
+ highres_classes = (
'Unclassified',
'Water',
'Impervious Surface',
@@ -234,7 +235,7 @@ class EnviroAtlas(GeoDataset):
'Orchards',
'Woody Wetlands',
'Emergent Wetlands',
- ]
+ )
highres_cmap = ListedColormap(
[
[1.00000000, 1.00000000, 1.00000000],
diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py
index ebf1d91f70b..7855c8bb3cf 100644
--- a/torchgeo/datasets/etci2021.py
+++ b/torchgeo/datasets/etci2021.py
@@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -56,9 +57,9 @@ class ETCI2021(NonGeoDataset):
the ETCI competition.
"""
- bands = ['VV', 'VH']
- masks = ['flood', 'water_body']
- metadata = {
+ bands = ('VV', 'VH')
+ masks = ('flood', 'water_body')
+ metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'filename': 'train.zip',
'md5': '1e95792fe0f6e3c9000abdeab2a8ab0f',
diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py
index 63a6f916526..1d65a55f60c 100644
--- a/torchgeo/datasets/eudem.py
+++ b/torchgeo/datasets/eudem.py
@@ -7,7 +7,7 @@
import os
import pathlib
from collections.abc import Callable, Iterable
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@@ -53,7 +53,7 @@ class EUDEM(RasterDataset):
zipfile_glob = 'eu_dem_v11_*[A-Z0-9].zip'
filename_regex = '(?P[eudem_v11]{10})_(?P[A-Z0-9]{6})'
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'eu_dem_v11_E00N20.zip': '96edc7e11bc299b994e848050d6be591',
'eu_dem_v11_E10N00.zip': 'e14be147ac83eddf655f4833d55c1571',
'eu_dem_v11_E10N10.zip': '2eb5187e4d827245b33768404529c709',
diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py
index bf0f173d4c6..cb5bb2a5bc5 100644
--- a/torchgeo/datasets/eurocrops.py
+++ b/torchgeo/datasets/eurocrops.py
@@ -61,7 +61,7 @@ class EuroCrops(VectorDataset):
date_format = '%Y'
# Filename and md5 of files in this dataset on zenodo.
- zenodo_files = [
+ zenodo_files: tuple[tuple[str, str], ...] = (
('AT_2021.zip', '490241df2e3d62812e572049fc0c36c5'),
('BE_VLG_2021.zip', 'ac4b9e12ad39b1cba47fdff1a786c2d7'),
('DE_LS_2021.zip', '6d94e663a3ff7988b32cb36ea24a724f'),
@@ -81,7 +81,7 @@ class EuroCrops(VectorDataset):
# Year is unknown for Romania portion (ny = no year).
# We skip since it is inconsistent with the rest of the data.
# ("RO_ny.zip", "648e1504097765b4b7f825decc838882"),
- ]
+ )
def __init__(
self,
diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py
index 8292257404c..26c1c860cf8 100644
--- a/torchgeo/datasets/eurosat.py
+++ b/torchgeo/datasets/eurosat.py
@@ -5,7 +5,7 @@
import os
from collections.abc import Callable, Sequence
-from typing import cast
+from typing import ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset):
* https://ieeexplore.ieee.org/document/8519248
"""
- url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' # noqa: E501
+ url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip'
filename = 'EuroSATallBands.zip'
md5 = '5ac12b3b2557aa56e1826e981e8e200e'
@@ -63,13 +63,13 @@ class EuroSAT(NonGeoClassificationDataset):
'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif'
)
- splits = ['train', 'val', 'test']
- split_urls = {
- 'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', # noqa: E501
- 'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', # noqa: E501
- 'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', # noqa: E501
+ splits = ('train', 'val', 'test')
+ split_urls: ClassVar[dict[str, str]] = {
+ 'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt',
+ 'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt',
+ 'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt',
}
- split_md5s = {
+ split_md5s: ClassVar[dict[str, str]] = {
'train': '908f142e73d6acdf3f482c5e80d851b1',
'val': '95de90f2aa998f70a3b2416bfe0687b4',
'test': '7ae5ab94471417b6e315763121e67c5f',
@@ -93,7 +93,10 @@ class EuroSAT(NonGeoClassificationDataset):
rgb_bands = ('B04', 'B03', 'B02')
- BAND_SETS = {'all': all_band_names, 'rgb': rgb_bands}
+ BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
+ 'all': all_band_names,
+ 'rgb': rgb_bands,
+ }
def __init__(
self,
@@ -302,12 +305,12 @@ class EuroSATSpatial(EuroSAT):
.. versionadded:: 0.6
"""
- split_urls = {
+ split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt',
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt',
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt',
}
- split_md5s = {
+ split_md5s: ClassVar[dict[str, str]] = {
'train': '7be3254be39f23ce4d4d144290c93292',
'val': 'acf392290050bb3df790dc8fc0ebf193',
'test': '5ec1733f9c16116bf0aa2d921fc613ef',
@@ -325,16 +328,16 @@ class EuroSAT100(EuroSAT):
.. versionadded:: 0.5
"""
- url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' # noqa: E501
+ url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip'
filename = 'EuroSAT100.zip'
md5 = 'c21c649ba747e86eda813407ef17d596'
- split_urls = {
- 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', # noqa: E501
- 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', # noqa: E501
- 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', # noqa: E501
+ split_urls: ClassVar[dict[str, str]] = {
+ 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt',
+ 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt',
+ 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt',
}
- split_md5s = {
+ split_md5s: ClassVar[dict[str, str]] = {
'train': '033d0c23e3a75e3fa79618b0e35fe1c7',
'val': '3e3f8b3c344182b8d126c4cc88f3f215',
'test': 'f908f151b950f270ad18e61153579794',
diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py
index 6d4337f57bc..d58968eaa19 100644
--- a/torchgeo/datasets/fair1m.py
+++ b/torchgeo/datasets/fair1m.py
@@ -6,7 +6,7 @@
import glob
import os
from collections.abc import Callable
-from typing import Any, cast
+from typing import Any, ClassVar, cast
from xml.etree.ElementTree import Element, parse
import matplotlib.patches as patches
@@ -119,7 +119,7 @@ class FAIR1M(NonGeoDataset):
.. versionadded:: 0.2
"""
- classes = {
+ classes: ClassVar[dict[str, dict[str, Any]]] = {
'Passenger Ship': {'id': 0, 'category': 'Ship'},
'Motorboat': {'id': 1, 'category': 'Ship'},
'Fishing Boat': {'id': 2, 'category': 'Ship'},
@@ -159,12 +159,12 @@ class FAIR1M(NonGeoDataset):
'Bridge': {'id': 36, 'category': 'Road'},
}
- filename_glob = {
+ filename_glob: ClassVar[dict[str, str]] = {
'train': os.path.join('train', '**', 'images', '*.tif'),
'val': os.path.join('validation', 'images', '*.tif'),
'test': os.path.join('test', 'images', '*.tif'),
}
- directories = {
+ directories: ClassVar[dict[str, tuple[str, ...]]] = {
'train': (
os.path.join('train', 'part1', 'images'),
os.path.join('train', 'part1', 'labelXml'),
@@ -175,9 +175,9 @@ class FAIR1M(NonGeoDataset):
os.path.join('validation', 'images'),
os.path.join('validation', 'labelXml'),
),
- 'test': (os.path.join('test', 'images')),
+ 'test': (os.path.join('test', 'images'),),
}
- paths = {
+ paths: ClassVar[dict[str, tuple[str, ...]]] = {
'train': (
os.path.join('train', 'part1', 'images.zip'),
os.path.join('train', 'part1', 'labelXml.zip'),
@@ -194,7 +194,7 @@ class FAIR1M(NonGeoDataset):
os.path.join('test', 'images2.zip'),
),
}
- urls = {
+ urls: ClassVar[dict[str, tuple[str, ...]]] = {
'train': (
'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf',
'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u',
@@ -211,7 +211,7 @@ class FAIR1M(NonGeoDataset):
'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0',
),
}
- md5s = {
+ md5s: ClassVar[dict[str, tuple[str, ...]]] = {
'train': (
'a460fe6b1b5b276bf856ce9ac72d6568',
'80f833ff355f91445c92a0c0c1fa7414',
diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py
index be40dfcf6d6..9370488f503 100644
--- a/torchgeo/datasets/fire_risk.py
+++ b/torchgeo/datasets/fire_risk.py
@@ -55,8 +55,8 @@ class FireRisk(NonGeoClassificationDataset):
md5 = 'a77b9a100d51167992ae8c51d26198a6'
filename = 'FireRisk.zip'
directory = 'FireRisk'
- splits = ['train', 'val']
- classes = [
+ splits = ('train', 'val')
+ classes = (
'High',
'Low',
'Moderate',
@@ -64,7 +64,7 @@ class FireRisk(NonGeoClassificationDataset):
'Very_High',
'Very_Low',
'Water',
- ]
+ )
def __init__(
self,
diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py
index 0eca812bb8b..33333b956ba 100644
--- a/torchgeo/datasets/forestdamage.py
+++ b/torchgeo/datasets/forestdamage.py
@@ -96,11 +96,8 @@ class ForestDamage(NonGeoDataset):
.. versionadded:: 0.3
"""
- classes = ['other', 'H', 'LD', 'HD']
- url = (
- 'https://lilablobssc.blob.core.windows.net/larch-casebearer/'
- 'Data_Set_Larch_Casebearer.zip'
- )
+ classes = ('other', 'H', 'LD', 'HD')
+ url = 'https://lilablobssc.blob.core.windows.net/larch-casebearer/Data_Set_Larch_Casebearer.zip'
data_dir = 'Data_Set_Larch_Casebearer'
md5 = '907815bcc739bff89496fac8f8ce63d7'
diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py
index f7362e9c692..68a7d853969 100644
--- a/torchgeo/datasets/geo.py
+++ b/torchgeo/datasets/geo.py
@@ -13,7 +13,7 @@
import sys
import warnings
from collections.abc import Callable, Iterable, Sequence
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import fiona
import fiona.transform
@@ -370,13 +370,13 @@ class RasterDataset(GeoDataset):
separate_files = False
#: Names of all available bands in the dataset
- all_bands: list[str] = []
+ all_bands: tuple[str, ...] = ()
#: Names of RGB bands in the dataset, used for plotting
- rgb_bands: list[str] = []
+ rgb_bands: tuple[str, ...] = ()
#: Color map for the dataset, used for plotting
- cmap: dict[int, tuple[int, int, int, int]] = {}
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {}
@property
def dtype(self) -> torch.dtype:
@@ -458,7 +458,7 @@ def __init__(
# See if file has a color map
if len(self.cmap) == 0:
try:
- self.cmap = src.colormap(1)
+ self.cmap = src.colormap(1) # type: ignore[misc]
except ValueError:
pass
diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py
index 078b83e6054..b42e6e58df6 100644
--- a/torchgeo/datasets/gid15.py
+++ b/torchgeo/datasets/gid15.py
@@ -66,8 +66,8 @@ class GID15(NonGeoDataset):
md5 = '615682bf659c3ed981826c6122c10c83'
filename = 'gid-15.zip'
directory = 'GID'
- splits = ['train', 'val', 'test']
- classes = [
+ splits = ('train', 'val', 'test')
+ classes = (
'background',
'industrial_land',
'urban_residential',
@@ -84,7 +84,7 @@ class GID15(NonGeoDataset):
'river',
'lake',
'pond',
- ]
+ )
def __init__(
self,
diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py
index 007466738d2..e117bf361c6 100644
--- a/torchgeo/datasets/globbiomass.py
+++ b/torchgeo/datasets/globbiomass.py
@@ -7,7 +7,7 @@
import os
import pathlib
from collections.abc import Callable, Iterable
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import torch
@@ -73,9 +73,9 @@ class GlobBiomass(RasterDataset):
is_image = False
dtype = torch.float32 # pixelwise regression
- measurements = ['agb', 'gsv']
+ measurements = ('agb', 'gsv')
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6',
'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79',
'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097',
diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py
index 67a9db80e87..6157579eb1b 100644
--- a/torchgeo/datasets/idtrees.py
+++ b/torchgeo/datasets/idtrees.py
@@ -6,7 +6,7 @@
import glob
import os
from collections.abc import Callable
-from typing import Any, cast, overload
+from typing import Any, ClassVar, cast, overload
import fiona
import matplotlib.pyplot as plt
@@ -100,7 +100,7 @@ class IDTReeS(NonGeoDataset):
.. versionadded:: 0.2
"""
- classes = {
+ classes: ClassVar[dict[str, str]] = {
'ACPE': 'Acer pensylvanicum L.',
'ACRU': 'Acer rubrum L.',
'ACSA3': 'Acer saccharum Marshall',
@@ -135,19 +135,22 @@ class IDTReeS(NonGeoDataset):
'ROPS': 'Robinia pseudoacacia L.',
'TSCA': 'Tsuga canadensis (L.) Carriere',
}
- metadata = {
+ metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
- 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1', # noqa: E501
+ 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1',
'md5': '5ddfa76240b4bb6b4a7861d1d31c299c',
'filename': 'IDTREES_competition_train_v2.zip',
},
'test': {
- 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1', # noqa: E501
+ 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1',
'md5': 'b108931c84a70f2a38a8234290131c9b',
'filename': 'IDTREES_competition_test_v2.zip',
},
}
- directories = {'train': ['train'], 'test': ['task1', 'task2']}
+ directories: ClassVar[dict[str, list[str]]] = {
+ 'train': ['train'],
+ 'test': ['task1', 'task2'],
+ }
image_size = (200, 200)
def __init__(
diff --git a/torchgeo/datasets/iobench.py b/torchgeo/datasets/iobench.py
index 80376df9579..608a9ccc17a 100644
--- a/torchgeo/datasets/iobench.py
+++ b/torchgeo/datasets/iobench.py
@@ -6,7 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@@ -40,9 +40,9 @@ class IOBench(IntersectionDataset):
.. versionadded:: 0.6
"""
- url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz' # noqa: E501
+ url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz'
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'original': 'e3a908a0fd1c05c1af2f4c65724d59b3',
'raw': 'e9603990441007ce7bba73bb8ba7d217',
'preprocessed': '9801f1240b238cb17525c865e413d1fd',
@@ -54,7 +54,7 @@ def __init__(
split: str = 'preprocessed',
crs: CRS | None = None,
res: float | None = None,
- bands: Sequence[str] | None = Landsat9.default_bands + ['SR_QA_AEROSOL'],
+ bands: Sequence[str] | None = [*Landsat9.default_bands, 'SR_QA_AEROSOL'],
classes: list[int] = [0],
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py
index 8adc3842c45..25830eaf3dc 100644
--- a/torchgeo/datasets/l7irish.py
+++ b/torchgeo/datasets/l7irish.py
@@ -8,7 +8,7 @@
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import torch
@@ -43,8 +43,8 @@ class L7IrishImage(RasterDataset):
"""
date_format = '%Y%m%d'
is_image = True
- rgb_bands = ['B30', 'B20', 'B10']
- all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80']
+ rgb_bands = ('B30', 'B20', 'B10')
+ all_bands = ('B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80')
class L7IrishMask(RasterDataset):
@@ -59,7 +59,7 @@ class L7IrishMask(RasterDataset):
_newmask2015\.TIF$
"""
is_image = False
- classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
+ classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud')
ordinal_map = torch.zeros(256, dtype=torch.long)
ordinal_map[64] = 1
ordinal_map[128] = 2
@@ -158,11 +158,11 @@ class L7Irish(IntersectionDataset):
* https://www.sciencebase.gov/catalog/item/573ccf18e4b0dae0d5e4b109
.. versionadded:: 0.5
- """ # noqa: E501
+ """
- url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' # noqa: E501
+ url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz'
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'austral': '0a34770b992a62abeb88819feb192436',
'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082',
'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa',
diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py
index 4865ec932b4..318efa2476a 100644
--- a/torchgeo/datasets/l8biome.py
+++ b/torchgeo/datasets/l8biome.py
@@ -7,7 +7,7 @@
import os
import pathlib
from collections.abc import Callable, Iterable, Sequence
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
import torch
@@ -36,8 +36,8 @@ class L8BiomeImage(RasterDataset):
"""
date_format = '%Y%j'
is_image = True
- rgb_bands = ['B4', 'B3', 'B2']
- all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']
+ rgb_bands = ('B4', 'B3', 'B2')
+ all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11')
class L8BiomeMask(RasterDataset):
@@ -57,7 +57,7 @@ class L8BiomeMask(RasterDataset):
"""
date_format = '%Y%j'
is_image = False
- classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
+ classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud')
ordinal_map = torch.zeros(256, dtype=torch.long)
ordinal_map[64] = 1
ordinal_map[128] = 2
@@ -116,11 +116,11 @@ class L8Biome(IntersectionDataset):
* https://doi.org/10.1016/j.rse.2017.03.026
.. versionadded:: 0.5
- """ # noqa: E501
+ """
- url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' # noqa: E501
+ url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz'
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'barren': '0eb691822d03dabd4f5ea8aadd0b41c3',
'forest': '4a5645596f6bb8cea44677f746ec676e',
'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c',
diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py
index 5273c03303c..d9a9643b524 100644
--- a/torchgeo/datasets/landcoverai.py
+++ b/torchgeo/datasets/landcoverai.py
@@ -9,7 +9,7 @@
import os
from collections.abc import Callable
from functools import lru_cache
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@@ -64,8 +64,8 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip'
filename = 'landcover.ai.v1.zip'
md5 = '3268c89070e8734b4e91d531c0617e03'
- classes = ['Background', 'Building', 'Woodland', 'Water', 'Road']
- cmap = {
+ classes = ('Background', 'Building', 'Woodland', 'Water', 'Road')
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0),
1: (97, 74, 74, 255),
2: (38, 115, 0, 255),
diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py
index 48647f5d247..8fb33b7c9cc 100644
--- a/torchgeo/datasets/landsat.py
+++ b/torchgeo/datasets/landsat.py
@@ -33,7 +33,7 @@ class Landsat(RasterDataset, abc.ABC):
* `Surface Temperature `_
* `Surface Reflectance `_
* `U.S. Analysis Ready Data `_
- """ # noqa: E501
+ """
# https://www.usgs.gov/landsat-missions/landsat-collection-2
filename_regex = r"""
@@ -55,7 +55,7 @@ class Landsat(RasterDataset, abc.ABC):
@property
@abc.abstractmethod
- def default_bands(self) -> list[str]:
+ def default_bands(self) -> tuple[str, ...]:
"""Bands to load by default."""
def __init__(
@@ -145,8 +145,8 @@ class Landsat1(Landsat):
filename_glob = 'LM01_*_{}.*'
- default_bands = ['B4', 'B5', 'B6', 'B7']
- rgb_bands = ['B6', 'B5', 'B4']
+ default_bands = ('B4', 'B5', 'B6', 'B7')
+ rgb_bands = ('B6', 'B5', 'B4')
class Landsat2(Landsat1):
@@ -166,8 +166,8 @@ class Landsat4MSS(Landsat):
filename_glob = 'LM04_*_{}.*'
- default_bands = ['B1', 'B2', 'B3', 'B4']
- rgb_bands = ['B3', 'B2', 'B1']
+ default_bands = ('B1', 'B2', 'B3', 'B4')
+ rgb_bands = ('B3', 'B2', 'B1')
class Landsat4TM(Landsat):
@@ -175,8 +175,8 @@ class Landsat4TM(Landsat):
filename_glob = 'LT04_*_{}.*'
- default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
- rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1']
+ default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
+ rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
class Landsat5MSS(Landsat4MSS):
@@ -196,8 +196,8 @@ class Landsat7(Landsat):
filename_glob = 'LE07_*_{}.*'
- default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
- rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1']
+ default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
+ rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
class Landsat8(Landsat):
@@ -205,11 +205,11 @@ class Landsat8(Landsat):
filename_glob = 'LC08_*_{}.*'
- default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
- rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2']
+ default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
+ rgb_bands = ('SR_B4', 'SR_B3', 'SR_B2')
class Landsat9(Landsat8):
- """Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" # noqa: E501
+ """Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2)."""
filename_glob = 'LC09_*_{}.*'
diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py
index 9dbc68136db..e664365ab2d 100644
--- a/torchgeo/datasets/levircd.py
+++ b/torchgeo/datasets/levircd.py
@@ -7,6 +7,7 @@
import glob
import os
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -26,8 +27,8 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC):
.. versionadded:: 0.6
"""
- splits: list[str] | dict[str, dict[str, str]]
- directories = ['A', 'B', 'label']
+ splits: ClassVar[tuple[str, ...] | dict[str, dict[str, str]]]
+ directories = ('A', 'B', 'label')
def __init__(
self,
@@ -237,7 +238,7 @@ class LEVIRCD(LEVIRCDBase):
.. versionadded:: 0.6
"""
- splits = {
+ splits: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-',
'filename': 'train.zip',
@@ -336,7 +337,7 @@ class LEVIRCDPlus(LEVIRCDBase):
md5 = '1adf156f628aa32fb2e8fe6cada16c04'
filename = 'LEVIR-CD+.zip'
directory = 'LEVIR-CD+'
- splits = ['train', 'test']
+ splits = ('train', 'test')
def _load_files(self, root: Path, split: str) -> list[dict[str, str]]:
"""Return the paths of the files in the dataset.
diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py
index 8c987548f90..0398410a674 100644
--- a/torchgeo/datasets/loveda.py
+++ b/torchgeo/datasets/loveda.py
@@ -5,7 +5,8 @@
import glob
import os
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -57,10 +58,10 @@ class LoveDA(NonGeoDataset):
.. versionadded:: 0.2
"""
- scenes = ['urban', 'rural']
- splits = ['train', 'val', 'test']
+ scenes = ('urban', 'rural')
+ splits = ('train', 'val', 'test')
- info_dict = {
+ info_dict: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1',
'filename': 'Train.zip',
@@ -78,7 +79,7 @@ class LoveDA(NonGeoDataset):
},
}
- classes = [
+ classes = (
'background',
'building',
'road',
@@ -87,13 +88,13 @@ class LoveDA(NonGeoDataset):
'forest',
'agriculture',
'no-data',
- ]
+ )
def __init__(
self,
root: Path = 'data',
split: str = 'train',
- scene: list[str] = ['urban', 'rural'],
+ scene: Sequence[str] = ['urban', 'rural'],
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py
index 66ce23ad70f..51289cf91cd 100644
--- a/torchgeo/datasets/mapinwild.py
+++ b/torchgeo/datasets/mapinwild.py
@@ -7,6 +7,7 @@
import shutil
from collections import defaultdict
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -36,7 +37,7 @@ class MapInWild(NonGeoDataset):
different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season
Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging
Radiometer Suite NightTime Day/Night band. The dataset consists of 8144
- images with the shape of 1920 × 1920 pixels. The images are weakly annotated
+ images with the shape of 1920 x 1920 pixels. The images are weakly annotated
from the World Database of Protected Areas (WDPA).
Dataset features:
@@ -54,9 +55,9 @@ class MapInWild(NonGeoDataset):
.. versionadded:: 0.5
"""
- url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' # noqa: E501
+ url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/'
- modality_urls = {
+ modality_urls: ClassVar[dict[str, set[str]]] = {
'esa_wc': {'esa_wc/ESA_WC.zip'},
'viirs': {'viirs/VIIRS.zip'},
'mask': {'mask/mask.zip'},
@@ -72,7 +73,7 @@ class MapInWild(NonGeoDataset):
'split_IDs': {'split_IDs/split_IDs.csv'},
}
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92',
'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4',
'mask.zip': '87c83a23a73998ad60d448d240b66225',
@@ -91,9 +92,12 @@ class MapInWild(NonGeoDataset):
'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f',
}
- mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)}
+ mask_cmap: ClassVar[dict[int, tuple[int, int, int]]] = {
+ 1: (0, 153, 0),
+ 0: (255, 255, 255),
+ }
- wc_cmap = {
+ wc_cmap: ClassVar[dict[int, tuple[int, int, int]]] = {
10: (0, 160, 0),
20: (150, 100, 0),
30: (255, 180, 0),
diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py
index 1938aefbafc..14e0ba7dc12 100644
--- a/torchgeo/datasets/millionaid.py
+++ b/torchgeo/datasets/millionaid.py
@@ -6,7 +6,7 @@
import glob
import os
from collections.abc import Callable
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@@ -48,7 +48,7 @@ class MillionAID(NonGeoDataset):
.. versionadded:: 0.3
"""
- multi_label_categories = [
+ multi_label_categories = (
'agriculture_land',
'airport_area',
'apartment',
@@ -122,9 +122,9 @@ class MillionAID(NonGeoDataset):
'wind_turbine',
'woodland',
'works',
- ]
+ )
- multi_class_categories = [
+ multi_class_categories = (
'apartment',
'apron',
'bare_land',
@@ -176,17 +176,17 @@ class MillionAID(NonGeoDataset):
'wastewater_plant',
'wind_turbine',
'works',
- ]
+ )
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'train': '1b40503cafa9b0601653ca36cd788852',
'test': '51a63ee3eeb1351889eacff349a983d8',
}
- filenames = {'train': 'train.zip', 'test': 'test.zip'}
+ filenames: ClassVar[dict[str, str]] = {'train': 'train.zip', 'test': 'test.zip'}
- tasks = ['multi-class', 'multi-label']
- splits = ['train', 'test']
+ tasks = ('multi-class', 'multi-label')
+ splits = ('train', 'test')
def __init__(
self,
diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py
index d8185782367..326dccd6d72 100644
--- a/torchgeo/datasets/naip.py
+++ b/torchgeo/datasets/naip.py
@@ -45,8 +45,8 @@ class NAIP(RasterDataset):
"""
# Plotting
- all_bands = ['R', 'G', 'B', 'NIR']
- rgb_bands = ['R', 'G', 'B']
+ all_bands = ('R', 'G', 'B', 'NIR')
+ rgb_bands = ('R', 'G', 'B')
def plot(
self,
diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py
index 83163391735..96633b2e35b 100644
--- a/torchgeo/datasets/nccm.py
+++ b/torchgeo/datasets/nccm.py
@@ -4,7 +4,7 @@
"""Northeastern China Crop Map Dataset."""
from collections.abc import Callable, Iterable
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
import torch
@@ -57,23 +57,23 @@ class NCCM(RasterDataset):
date_format = '%Y'
is_image = False
- urls = {
+ urls: ClassVar[dict[int, str]] = {
2019: 'https://figshare.com/ndownloader/files/25070540',
2018: 'https://figshare.com/ndownloader/files/25070624',
2017: 'https://figshare.com/ndownloader/files/25070582',
}
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2019: '0d062bbd42e483fdc8239d22dba7020f',
2018: 'b3bb4894478d10786aa798fb11693ec1',
2017: 'd047fbe4a85341fa6248fd7e0badab6c',
}
- fnames = {
+ fnames: ClassVar[dict[int, str]] = {
2019: 'CDL2019_clip.tif',
2018: 'CDL2018_clip1.tif',
2017: 'CDL2017_clip.tif',
}
- cmap = {
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 255, 0, 255),
1: (255, 0, 0, 255),
2: (255, 255, 0, 255),
diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py
index e7113b6709f..7f4498c76da 100644
--- a/torchgeo/datasets/nlcd.py
+++ b/torchgeo/datasets/nlcd.py
@@ -7,7 +7,7 @@
import os
import pathlib
from collections.abc import Callable, Iterable
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
import torch
@@ -67,7 +67,7 @@ class NLCD(RasterDataset):
* 2019: https://doi.org/10.5066/P9KZCM54
.. versionadded:: 0.5
- """ # noqa: E501
+ """
filename_glob = 'nlcd_*_land_cover_l48_*.img'
filename_regex = (
@@ -79,7 +79,7 @@ class NLCD(RasterDataset):
url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip'
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2001: '538166a4d783204764e3df3b221fc4cd',
2006: '67454e7874a00294adb9442374d0c309',
2011: 'ea524c835d173658eeb6fa3c8e6b917b',
@@ -87,7 +87,7 @@ class NLCD(RasterDataset):
2019: '82851c3f8105763b01c83b4a9e6f3961',
}
- cmap = {
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0),
11: (70, 107, 159, 255),
12: (209, 222, 248, 255),
diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py
index 9e341b6822d..f19f4c9a2ce 100644
--- a/torchgeo/datasets/openbuildings.py
+++ b/torchgeo/datasets/openbuildings.py
@@ -9,7 +9,7 @@
import pathlib
import sys
from collections.abc import Callable, Iterable
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import fiona
import fiona.transform
@@ -61,7 +61,7 @@ class OpenBuildings(VectorDataset):
.. versionadded:: 0.3
"""
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17',
'04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1',
'05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670',
diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py
index 808ca93ab09..28f7714a7c6 100644
--- a/torchgeo/datasets/oscd.py
+++ b/torchgeo/datasets/oscd.py
@@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -50,7 +51,7 @@ class OSCD(NonGeoDataset):
.. versionadded:: 0.2
"""
- urls = {
+ urls: ClassVar[dict[str, str]] = {
'Onera Satellite Change Detection dataset - Images.zip': (
'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download'
),
@@ -61,7 +62,7 @@ class OSCD(NonGeoDataset):
'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download'
),
}
- md5s = {
+ md5s: ClassVar[dict[str, str]] = {
'Onera Satellite Change Detection dataset - Images.zip': (
'c50d4a2941da64e03a47ac4dec63d915'
),
@@ -75,9 +76,9 @@ class OSCD(NonGeoDataset):
zipfile_glob = '*Onera*.zip'
filename_glob = '*Onera*'
- splits = ['train', 'test']
+ splits = ('train', 'test')
- colormap = ['blue']
+ colormap = ('blue',)
all_bands = (
'B01',
@@ -319,7 +320,7 @@ def get_masked(img: Tensor) -> 'np.typing.NDArray[np.uint8]':
torch.from_numpy(rgb_img),
sample['mask'],
alpha=alpha,
- colors=self.colormap,
+ colors=list(self.colormap),
)
return array
diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py
index bdad515f66b..430a6330d89 100644
--- a/torchgeo/datasets/pastis.py
+++ b/torchgeo/datasets/pastis.py
@@ -5,6 +5,7 @@
import os
from collections.abc import Callable, Sequence
+from typing import ClassVar
import fiona
import matplotlib.pyplot as plt
@@ -70,7 +71,7 @@ class PASTIS(NonGeoDataset):
.. versionadded:: 0.5
"""
- classes = [
+ classes = (
'background', # all non-agricultural land
'meadow',
'soft_winter_wheat',
@@ -91,8 +92,8 @@ class PASTIS(NonGeoDataset):
'mixed_cereal',
'sorghum',
'void_label', # for parcels mostly outside their patch
- ]
- cmap = {
+ )
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255),
1: (174, 199, 232, 255),
2: (255, 127, 14, 255),
@@ -118,7 +119,7 @@ class PASTIS(NonGeoDataset):
filename = 'PASTIS-R.zip'
url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1'
md5 = '4887513d6c2d2b07fa935d325bd53e09'
- prefix = {
+ prefix: ClassVar[dict[str, str]] = {
's2': os.path.join('DATA_S2', 'S2_'),
's1a': os.path.join('DATA_S1A', 'S1A_'),
's1d': os.path.join('DATA_S1D', 'S1D_'),
@@ -232,7 +233,7 @@ def _load_semantic_targets(self, index: int) -> Tensor:
Returns:
the target mask
"""
- # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501
+ # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201
# even though the mask file is 3 bands, we just select the first band
array = np.load(self.files[index]['semantic'])[0].astype(np.uint8)
tensor = torch.from_numpy(array).long()
diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py
index 479ca3cf170..51f1ebd0441 100644
--- a/torchgeo/datasets/potsdam.py
+++ b/torchgeo/datasets/potsdam.py
@@ -5,6 +5,7 @@
import os
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -54,12 +55,12 @@ class Potsdam2D(NonGeoDataset):
* https://doi.org/10.5194/isprsannals-I-3-293-2012
.. versionadded:: 0.2
- """ # noqa: E501
+ """
- filenames = ['4_Ortho_RGBIR.zip', '5_Labels_all.zip']
- md5s = ['c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db']
+ filenames = ('4_Ortho_RGBIR.zip', '5_Labels_all.zip')
+ md5s = ('c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db')
image_root = '4_Ortho_RGBIR'
- splits = {
+ splits: ClassVar[dict[str, list[str]]] = {
'train': [
'top_potsdam_2_10',
'top_potsdam_2_11',
@@ -103,22 +104,22 @@ class Potsdam2D(NonGeoDataset):
'top_potsdam_7_13',
],
}
- classes = [
+ classes = (
'Clutter/background',
'Impervious surfaces',
'Building',
'Low Vegetation',
'Tree',
'Car',
- ]
- colormap = [
+ )
+ colormap = (
(255, 0, 0),
(255, 255, 255),
(0, 0, 255),
(0, 255, 255),
(0, 255, 0),
(255, 255, 0),
- ]
+ )
def __init__(
self,
@@ -257,7 +258,7 @@ def plot(
"""
ncols = 1
image1 = draw_semantic_segmentation_masks(
- sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap
+ sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap)
)
if 'prediction' in sample:
ncols += 1
@@ -265,7 +266,7 @@ def plot(
sample['image'][:3],
sample['prediction'],
alpha=alpha,
- colors=self.colormap,
+ colors=list(self.colormap),
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py
index 3fd2501dd0f..811d79cff08 100644
--- a/torchgeo/datasets/quakeset.py
+++ b/torchgeo/datasets/quakeset.py
@@ -5,7 +5,7 @@
import os
from collections.abc import Callable
-from typing import Any, cast
+from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@@ -61,8 +61,12 @@ class QuakeSet(NonGeoDataset):
filename = 'earthquakes.h5'
url = 'https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5'
md5 = '76fc7c76b7ca56f4844d852e175e1560'
- splits = {'train': 'train', 'val': 'validation', 'test': 'test'}
- classes = ['unaffected_area', 'earthquake_affected_area']
+ splits: ClassVar[dict[str, str]] = {
+ 'train': 'train',
+ 'val': 'validation',
+ 'test': 'test',
+ }
+ classes = ('unaffected_area', 'earthquake_affected_area')
def __init__(
self,
diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py
index bd28ab83f5d..930799ed205 100644
--- a/torchgeo/datasets/reforestree.py
+++ b/torchgeo/datasets/reforestree.py
@@ -56,7 +56,7 @@ class ReforesTree(NonGeoDataset):
.. versionadded:: 0.3
"""
- classes = ['other', 'banana', 'cacao', 'citrus', 'fruit', 'timber']
+ classes = ('other', 'banana', 'cacao', 'citrus', 'fruit', 'timber')
url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1'
md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302'
diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py
index 0c8108a78b9..fd33b634fde 100644
--- a/torchgeo/datasets/resisc45.py
+++ b/torchgeo/datasets/resisc45.py
@@ -5,7 +5,7 @@
import os
from collections.abc import Callable
-from typing import cast
+from typing import ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@@ -98,13 +98,13 @@ class RESISC45(NonGeoClassificationDataset):
filename = 'NWPU-RESISC45.zip'
directory = 'NWPU-RESISC45'
- splits = ['train', 'val', 'test']
- split_urls = {
+ splits = ('train', 'val', 'test')
+ split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-train.txt',
'val': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-val.txt',
'test': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-test.txt',
}
- split_md5s = {
+ split_md5s: ClassVar[dict[str, str]] = {
'train': 'b5a4c05a37de15e4ca886696a85c403e',
'val': 'a0770cee4c5ca20b8c32bbd61e114805',
'test': '3dda9e4988b47eb1de9f07993653eb08',
diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py
index 510039d8dcd..50459f5da9f 100644
--- a/torchgeo/datasets/rwanda_field_boundary.py
+++ b/torchgeo/datasets/rwanda_field_boundary.py
@@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -55,11 +56,11 @@ class RwandaFieldBoundary(NonGeoDataset):
url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition'
- splits = {'train': 57, 'test': 13}
+ splits: ClassVar[dict[str, int]] = {'train': 57, 'test': 13}
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
all_bands = ('B01', 'B02', 'B03', 'B04')
rgb_bands = ('B03', 'B02', 'B01')
- classes = ['No field-boundary', 'Field-boundary']
+ classes = ('No field-boundary', 'Field-boundary')
def __init__(
self,
diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py
index 6d19fed7fcc..04cac92b36c 100644
--- a/torchgeo/datasets/seasonet.py
+++ b/torchgeo/datasets/seasonet.py
@@ -6,6 +6,7 @@
import os
import random
from collections.abc import Callable, Collection, Iterable
+from typing import ClassVar
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
@@ -85,51 +86,51 @@ class SeasoNet(NonGeoDataset):
.. versionadded:: 0.5
"""
- metadata = [
+ metadata = (
{
'name': 'spring',
'ext': '.zip',
- 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', # noqa: E501
+ 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip',
'md5': 'de4cdba7b6196aff624073991b187561',
},
{
'name': 'summer',
'ext': '.zip',
- 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', # noqa: E501
+ 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip',
'md5': '6a54d4e134d27ae4eb03f180ee100550',
},
{
'name': 'fall',
'ext': '.zip',
- 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', # noqa: E501
+ 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip',
'md5': '5f94920fe41a63c6bfbab7295f7d6b95',
},
{
'name': 'winter',
'ext': '.zip',
- 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', # noqa: E501
+ 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip',
'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48',
},
{
'name': 'snow',
'ext': '.zip',
- 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', # noqa: E501
+ 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip',
'md5': 'e1b300994143f99ebb03f51d6ab1cbe6',
},
{
'name': 'splits',
'ext': '.zip',
- 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', # noqa: E501
+ 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip',
'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed',
},
{
'name': 'meta.csv',
'ext': '',
- 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', # noqa: E501
+ 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv',
'md5': '43ea07974936a6bf47d989c32e16afe7',
},
- ]
- classes = [
+ )
+ classes = (
'Continuous urban fabric',
'Discontinuous urban fabric',
'Industrial or commercial units',
@@ -163,12 +164,17 @@ class SeasoNet(NonGeoDataset):
'Coastal lagoons',
'Estuaries',
'Sea and ocean',
- ]
- all_seasons = {'Spring', 'Summer', 'Fall', 'Winter', 'Snow'}
+ )
+ all_seasons = frozenset({'Spring', 'Summer', 'Fall', 'Winter', 'Snow'})
all_bands = ('10m_RGB', '10m_IR', '20m', '60m')
- band_nums = {'10m_RGB': 3, '10m_IR': 1, '20m': 6, '60m': 2}
- splits = ['train', 'val', 'test']
- cmap = {
+ band_nums: ClassVar[dict[str, int]] = {
+ '10m_RGB': 3,
+ '10m_IR': 1,
+ '20m': 6,
+ '60m': 2,
+ }
+ splits = ('train', 'val', 'test')
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (230, 000, 77, 255),
1: (255, 000, 000, 255),
2: (204, 77, 242, 255),
@@ -331,7 +337,7 @@ def _load_image(self, index: int) -> Tensor:
for band in self.bands:
with rasterio.open(f'{path}_{band}.tif') as f:
array = f.read(
- out_shape=[f.count] + list(self.image_size),
+ out_shape=[f.count, *list(self.image_size)],
out_dtype='int32',
resampling=Resampling.bilinear,
)
diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py
index 74b8ebba54a..0d1df7f7bbe 100644
--- a/torchgeo/datasets/seco.py
+++ b/torchgeo/datasets/seco.py
@@ -5,7 +5,8 @@
import os
import random
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -37,7 +38,7 @@ class SeasonalContrastS2(NonGeoDataset):
* https://arxiv.org/pdf/2103.16607.pdf
"""
- all_bands = [
+ all_bands = (
'B1',
'B2',
'B3',
@@ -50,10 +51,10 @@ class SeasonalContrastS2(NonGeoDataset):
'B9',
'B11',
'B12',
- ]
- rgb_bands = ['B4', 'B3', 'B2']
+ )
+ rgb_bands = ('B4', 'B3', 'B2')
- metadata = {
+ metadata: ClassVar[dict[str, dict[str, str]]] = {
'100k': {
'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1',
'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0',
@@ -73,7 +74,7 @@ def __init__(
root: Path = 'data',
version: str = '100k',
seasons: int = 1,
- bands: list[str] = rgb_bands,
+ bands: Sequence[str] = rgb_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py
index 8b8ee57803c..20183f06421 100644
--- a/torchgeo/datasets/sen12ms.py
+++ b/torchgeo/datasets/sen12ms.py
@@ -5,6 +5,7 @@
import os
from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -63,9 +64,9 @@ class SEN12MS(NonGeoDataset):
or manually downloaded from https://dataserv.ub.tum.de/s/m1474000
and https://github.com/schmitt-muc/SEN12MS/tree/master/splits.
This download will likely take several hours.
- """ # noqa: E501
+ """
- BAND_SETS: dict[str, tuple[str, ...]] = {
+ BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
'all': (
'VV',
'VH',
@@ -120,9 +121,9 @@ class SEN12MS(NonGeoDataset):
'B12',
)
- rgb_bands = ['B04', 'B03', 'B02']
+ rgb_bands = ('B04', 'B03', 'B02')
- filenames = [
+ filenames = (
'ROIs1158_spring_lc.tar.gz',
'ROIs1158_spring_s1.tar.gz',
'ROIs1158_spring_s2.tar.gz',
@@ -137,16 +138,16 @@ class SEN12MS(NonGeoDataset):
'ROIs2017_winter_s2.tar.gz',
'train_list.txt',
'test_list.txt',
- ]
- light_filenames = [
+ )
+ light_filenames = (
'ROIs1158_spring',
'ROIs1868_summer',
'ROIs1970_fall',
'ROIs2017_winter',
'train_list.txt',
'test_list.txt',
- ]
- md5s = [
+ )
+ md5s = (
'6e2e8fa8b8cba77ddab49fd20ff5c37b',
'fba019bb27a08c1db96b31f718c34d79',
'd58af2c15a16f376eb3308dc9b685af2',
@@ -161,7 +162,7 @@ class SEN12MS(NonGeoDataset):
'3807545661288dcca312c9c538537b63',
'0a68d4e1eb24f128fccdb930000b2546',
'c7faad064001e646445c4c634169484d',
- ]
+ )
def __init__(
self,
diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py
index 163c771a6b9..3dc5fc31771 100644
--- a/torchgeo/datasets/sentinel.py
+++ b/torchgeo/datasets/sentinel.py
@@ -137,7 +137,7 @@ class Sentinel1(Sentinel):
\.
"""
date_format = '%Y%m%dT%H%M%S'
- all_bands = ['HH', 'HV', 'VV', 'VH']
+ all_bands = ('HH', 'HV', 'VV', 'VH')
separate_files = True
def __init__(
@@ -277,7 +277,7 @@ class Sentinel2(Sentinel):
date_format = '%Y%m%dT%H%M%S'
# https://gisgeography.com/sentinel-2-bands-combinations/
- all_bands = [
+ all_bands: tuple[str, ...] = (
'B01',
'B02',
'B03',
@@ -291,8 +291,8 @@ class Sentinel2(Sentinel):
'B10',
'B11',
'B12',
- ]
- rgb_bands = ['B04', 'B03', 'B02']
+ )
+ rgb_bands = ('B04', 'B03', 'B02')
separate_files = True
diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py
index 8a8882f3598..a3a842561fa 100644
--- a/torchgeo/datasets/skippd.py
+++ b/torchgeo/datasets/skippd.py
@@ -5,7 +5,7 @@
import os
from collections.abc import Callable
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -62,8 +62,8 @@ class SKIPPD(NonGeoDataset):
.. versionadded:: 0.5
"""
- url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' # noqa: E501
- md5 = {
+ url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}'
+ md5: ClassVar[dict[str, str]] = {
'forecast': 'f4f3509ddcc83a55c433be9db2e51077',
'nowcast': '0000761d403e45bb5f86c21d3c69aa80',
}
@@ -71,9 +71,9 @@ class SKIPPD(NonGeoDataset):
data_file_name = '2017_2019_images_pv_processed_{}.hdf5'
zipfile_name = '2017_2019_images_pv_processed_{}.zip'
- valid_splits = ['trainval', 'test']
+ valid_splits = ('trainval', 'test')
- valid_tasks = ['nowcast', 'forecast']
+ valid_tasks = ('nowcast', 'forecast')
dateformat = '%m/%d/%Y, %H:%M:%S'
diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py
index e90e89b4e34..4840a48e468 100644
--- a/torchgeo/datasets/so2sat.py
+++ b/torchgeo/datasets/so2sat.py
@@ -5,7 +5,7 @@
import os
from collections.abc import Callable, Sequence
-from typing import cast
+from typing import ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@@ -103,10 +103,10 @@ class So2Sat(NonGeoDataset):
This dataset requires the following additional library to be installed:
* ``_ to load the dataset
- """ # noqa: E501
+ """
- versions = ['2', '3_random', '3_block', '3_culture_10']
- filenames_by_version = {
+ versions = ('2', '3_random', '3_block', '3_culture_10')
+ filenames_by_version: ClassVar[dict[str, dict[str, str]]] = {
'2': {
'train': 'training.h5',
'validation': 'validation.h5',
@@ -119,7 +119,7 @@ class So2Sat(NonGeoDataset):
'test': 'culture_10/testing.h5',
},
}
- md5s_by_version = {
+ md5s_by_version: ClassVar[dict[str, dict[str, str]]] = {
'2': {
'train': '702bc6a9368ebff4542d791e53469244',
'validation': '71cfa6795de3e22207229d06d6f8775d',
@@ -139,7 +139,7 @@ class So2Sat(NonGeoDataset):
},
}
- classes = [
+ classes = (
'Compact high rise',
'Compact mid rise',
'Compact low rise',
@@ -157,7 +157,7 @@ class So2Sat(NonGeoDataset):
'Bare rock or paved',
'Bare soil or sand',
'Water',
- ]
+ )
all_s1_band_names = (
'S1_B1',
@@ -183,9 +183,9 @@ class So2Sat(NonGeoDataset):
)
all_band_names = all_s1_band_names + all_s2_band_names
- rgb_bands = ['S2_B04', 'S2_B03', 'S2_B02']
+ rgb_bands = ('S2_B04', 'S2_B03', 'S2_B02')
- BAND_SETS = {
+ BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
'all': all_band_names,
's1': all_s1_band_names,
's2': all_s2_band_names,
diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py
index 688ce166924..7b79f44c8d2 100644
--- a/torchgeo/datasets/south_africa_crop_type.py
+++ b/torchgeo/datasets/south_africa_crop_type.py
@@ -6,8 +6,8 @@
import os
import pathlib
import re
-from collections.abc import Callable, Iterable
-from typing import Any, cast
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import torch
@@ -79,9 +79,9 @@ class SouthAfricaCropType(RasterDataset):
_10m
"""
date_format = '%Y_%m_%d'
- rgb_bands = ['B04', 'B03', 'B02']
- s1_bands = ['VH', 'VV']
- s2_bands = [
+ rgb_bands = ('B04', 'B03', 'B02')
+ s1_bands = ('VH', 'VV')
+ s2_bands = (
'B01',
'B02',
'B03',
@@ -94,9 +94,9 @@ class SouthAfricaCropType(RasterDataset):
'B09',
'B11',
'B12',
- ]
- all_bands: list[str] = s1_bands + s2_bands
- cmap = {
+ )
+ all_bands = s1_bands + s2_bands
+ cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255),
1: (255, 211, 0, 255),
2: (255, 37, 37, 255),
@@ -113,8 +113,8 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
- classes: list[int] = list(cmap.keys()),
- bands: list[str] = s2_bands,
+ classes: Sequence[int] = list(cmap.keys()),
+ bands: Sequence[str] = s2_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
) -> None:
diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py
index 1ee8d7cb79e..3ec4b559472 100644
--- a/torchgeo/datasets/south_america_soybean.py
+++ b/torchgeo/datasets/south_america_soybean.py
@@ -5,7 +5,7 @@
import pathlib
from collections.abc import Callable, Iterable
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@@ -47,7 +47,7 @@ class SouthAmericaSoybean(RasterDataset):
is_image = False
url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif'
- md5s = {
+ md5s: ClassVar[dict[int, str]] = {
2021: 'edff3ada13a1a9910d1fe844d28ae4f',
2020: '0709dec807f576c9707c8c7e183db31',
2019: '441836493bbcd5e123cff579a58f5a4f',
diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py
index 0ff02c55da4..21a31657f58 100644
--- a/torchgeo/datasets/spacenet.py
+++ b/torchgeo/datasets/spacenet.py
@@ -8,7 +8,7 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Callable
-from typing import Any
+from typing import Any, ClassVar
import fiona
import matplotlib.pyplot as plt
@@ -55,9 +55,9 @@ class SpaceNet(NonGeoDataset, ABC):
image_glob = '*.tif'
mask_glob = '*.geojson'
file_regex = r'_img(\d+)\.'
- chip_size: dict[str, tuple[int, int]] = {}
+ chip_size: ClassVar[dict[str, tuple[int, int]]] = {}
- cities = {
+ cities: ClassVar[dict[int, str]] = {
1: 'Rio',
2: 'Vegas',
3: 'Paris',
@@ -98,7 +98,7 @@ def valid_images(self) -> dict[str, list[str]]:
@property
@abstractmethod
- def valid_masks(self) -> list[str]:
+ def valid_masks(self) -> tuple[str, ...]:
"""List of valid masks."""
def __init__(
@@ -426,7 +426,7 @@ class SpaceNet1(SpaceNet):
directory_glob = '{product}'
dataset_id = 'SN1_buildings'
- tarballs = {
+ tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
1: [
'SN1_buildings_train_AOI_1_Rio_3band.tar.gz',
@@ -441,7 +441,7 @@ class SpaceNet1(SpaceNet):
]
},
}
- md5s = {
+ md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
1: [
'279e334a2120ecac70439ea246174516',
@@ -453,10 +453,16 @@ class SpaceNet1(SpaceNet):
1: ['18283d78b21c239bc1831f3bf1d2c996', '732b3a40603b76e80aac84e002e2b3e8']
},
}
- valid_aois = {'train': [1], 'test': [1]}
- valid_images = {'train': ['3band', '8band'], 'test': ['3band', '8band']}
- valid_masks = ['geojson']
- chip_size = {'3band': (406, 439), '8band': (102, 110)}
+ valid_aois: ClassVar[dict[str, list[int]]] = {'train': [1], 'test': [1]}
+ valid_images: ClassVar[dict[str, list[str]]] = {
+ 'train': ['3band', '8band'],
+ 'test': ['3band', '8band'],
+ }
+ valid_masks = ('geojson',)
+ chip_size: ClassVar[dict[str, tuple[int, int]]] = {
+ '3band': (406, 439),
+ '8band': (102, 110),
+ }
class SpaceNet2(SpaceNet):
@@ -522,7 +528,7 @@ class SpaceNet2(SpaceNet):
"""
dataset_id = 'SN2_buildings'
- tarballs = {
+ tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
2: ['SN2_buildings_train_AOI_2_Vegas.tar.gz'],
3: ['SN2_buildings_train_AOI_3_Paris.tar.gz'],
@@ -536,7 +542,7 @@ class SpaceNet2(SpaceNet):
5: ['AOI_5_Khartoum_Test_public.tar.gz'],
},
}
- md5s = {
+ md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
2: ['307da318bc43aaf9481828f92eda9126'],
3: ['4db469e3e4e7bf025368ad730aec0888'],
@@ -550,13 +556,16 @@ class SpaceNet2(SpaceNet):
5: ['037d7be10530f0dd1c43d4ef79f3236e'],
},
}
- valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]}
- valid_images = {
+ valid_aois: ClassVar[dict[str, list[int]]] = {
+ 'train': [2, 3, 4, 5],
+ 'test': [2, 3, 4, 5],
+ }
+ valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
}
- valid_masks = [os.path.join('geojson', 'buildings')]
- chip_size = {'MUL': (163, 163)}
+ valid_masks = (os.path.join('geojson', 'buildings'),)
+ chip_size: ClassVar[dict[str, tuple[int, int]]] = {'MUL': (163, 163)}
class SpaceNet3(SpaceNet):
@@ -624,7 +633,7 @@ class SpaceNet3(SpaceNet):
"""
dataset_id = 'SN3_roads'
- tarballs = {
+ tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
2: [
'SN3_roads_train_AOI_2_Vegas.tar.gz',
@@ -650,7 +659,7 @@ class SpaceNet3(SpaceNet):
5: ['SN3_roads_test_public_AOI_5_Khartoum.tar.gz'],
},
}
- md5s = {
+ md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
2: ['06317255b5e0c6df2643efd8a50f22ae', '4acf7846ed8121db1319345cfe9fdca9'],
3: ['c13baf88ee10fe47870c303223cabf82', 'abc8199d4c522d3a14328f4f514702ad'],
@@ -664,12 +673,15 @@ class SpaceNet3(SpaceNet):
5: ['f367c79fa0fc1d38e63a0fdd065ed957'],
},
}
- valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]}
- valid_images = {
+ valid_aois: ClassVar[dict[str, list[int]]] = {
+ 'train': [2, 3, 4, 5],
+ 'test': [2, 3, 4, 5],
+ }
+ valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PS-MS', 'PAN', 'PS-RGB'],
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
}
- valid_masks = ['geojson_roads', 'geojson_roads_speed']
+ valid_masks: tuple[str, ...] = ('geojson_roads', 'geojson_roads_speed')
class SpaceNet4(SpaceNet):
@@ -708,7 +720,7 @@ class SpaceNet4(SpaceNet):
directory_glob = os.path.join('**', '{product}')
file_regex = r'_(\d+_\d+)\.'
dataset_id = 'SN4_buildings'
- tarballs = {
+ tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
6: [
'Atlanta_nadir7_catid_1030010003D22F00.tar.gz',
@@ -743,7 +755,7 @@ class SpaceNet4(SpaceNet):
},
'test': {6: ['SN4_buildings_AOI_6_Atlanta_test_public.tar.gz']},
}
- md5s = {
+ md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
6: [
'd41ab6ec087b07e1e046c55d1fa5754b',
@@ -778,12 +790,12 @@ class SpaceNet4(SpaceNet):
},
'test': {6: ['0ec3874bfc19aed63b33ac47b039aace']},
}
- valid_aois = {'train': [6], 'test': [6]}
- valid_images = {
+ valid_aois: ClassVar[dict[str, list[int]]] = {'train': [6], 'test': [6]}
+ valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PAN', 'Pan-Sharpen'],
'test': ['MS', 'PAN', 'Pan-Sharpen'],
}
- valid_masks = [os.path.join('geojson', 'spacenet-buildings')]
+ valid_masks = (os.path.join('geojson', 'spacenet-buildings'),)
class SpaceNet5(SpaceNet3):
@@ -850,26 +862,26 @@ class SpaceNet5(SpaceNet3):
file_regex = r'_chip(\d+)\.'
dataset_id = 'SN5_roads'
- tarballs = {
+ tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
7: ['SN5_roads_train_AOI_7_Moscow.tar.gz'],
8: ['SN5_roads_train_AOI_8_Mumbai.tar.gz'],
},
'test': {9: ['SN5_roads_test_public_AOI_9_San_Juan.tar.gz']},
}
- md5s = {
+ md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
7: ['03082d01081a6d8df2bc5a9645148d2a'],
8: ['1ee20ba781da6cb7696eef9a95a5bdcc'],
},
'test': {9: ['fc45afef219dfd3a20f2d4fc597f6882']},
}
- valid_aois = {'train': [7, 8], 'test': [9]}
- valid_images = {
+ valid_aois: ClassVar[dict[str, list[int]]] = {'train': [7, 8], 'test': [9]}
+ valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
'test': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
}
- valid_masks = ['geojson_roads_speed']
+ valid_masks = ('geojson_roads_speed',)
class SpaceNet6(SpaceNet):
@@ -937,20 +949,20 @@ class SpaceNet6(SpaceNet):
file_regex = r'_tile_(\d+)\.'
dataset_id = 'SN6_buildings'
- tarballs = {
+ tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {11: ['SN6_buildings_AOI_11_Rotterdam_train.tar.gz']},
'test': {11: ['SN6_buildings_AOI_11_Rotterdam_test_public.tar.gz']},
}
- md5s = {
+ md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {11: ['10ca26d2287716e3b6ef0cf0ad9f946e']},
'test': {11: ['a07823a5e536feeb8bb6b6f0cb43cf05']},
}
- valid_aois = {'train': [11], 'test': [11]}
- valid_images = {
+ valid_aois: ClassVar[dict[str, list[int]]] = {'train': [11], 'test': [11]}
+ valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'],
'test': ['SAR-Intensity'],
}
- valid_masks = ['geojson_buildings']
+ valid_masks = ('geojson_buildings',)
class SpaceNet7(SpaceNet):
@@ -958,7 +970,7 @@ class SpaceNet7(SpaceNet):
`SpaceNet 7 `_ is a dataset which
consist of medium resolution (4.0m) satellite imagery mosaics acquired from
- Planet Labs’ Dove constellation between 2017 and 2020. It includes ≈ 24
+ Planet Labs' Dove constellation between 2017 and 2020. It includes ≈ 24
images (one per month) covering > 100 unique geographies, and comprises >
40,000 km2 of imagery and exhaustive polygon labels of building footprints
therein, totaling over 11M individual annotations.
@@ -993,18 +1005,24 @@ class SpaceNet7(SpaceNet):
mask_glob = '*_Buildings.geojson'
file_regex = r'global_monthly_(\d+.*\d+)'
dataset_id = 'SN7_buildings'
- tarballs = {
+ tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {0: ['SN7_buildings_train.tar.gz']},
'test': {0: ['SN7_buildings_test_public.tar.gz']},
}
- md5s = {
+ md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {0: ['6eda13b9c28f6f5cdf00a7e8e218c1b1']},
'test': {0: ['b3bde95a0f8f32f3bfeba49464b9bc97']},
}
- valid_aois = {'train': [0], 'test': [0]}
- valid_images = {'train': ['images', 'images_masked'], 'test': ['images_masked']}
- valid_masks = ['labels', 'labels_match', 'labels_match_pix']
- chip_size = {'images': (1024, 1024), 'images_masked': (1024, 1024)}
+ valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]}
+ valid_images: ClassVar[dict[str, list[str]]] = {
+ 'train': ['images', 'images_masked'],
+ 'test': ['images_masked'],
+ }
+ valid_masks = ('labels', 'labels_match', 'labels_match_pix')
+ chip_size: ClassVar[dict[str, tuple[int, int]]] = {
+ 'images': (1024, 1024),
+ 'images_masked': (1024, 1024),
+ }
class SpaceNet8(SpaceNet):
@@ -1024,7 +1042,7 @@ class SpaceNet8(SpaceNet):
directory_glob = '{product}'
file_regex = r'(\d+_\d+_\d+)\.'
dataset_id = 'SN8_floods'
- tarballs = {
+ tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
0: [
'Germany_Training_Public.tar.gz',
@@ -1033,16 +1051,19 @@ class SpaceNet8(SpaceNet):
},
'test': {0: ['Louisiana-West_Test_Public.tar.gz']},
}
- md5s = {
+ md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
0: ['81383a9050b93e8f70c8557d4568e8a2', 'fa40ae3cf6ac212c90073bf93d70bd95']
},
'test': {0: ['d41d8cd98f00b204e9800998ecf8427e']},
}
- valid_aois = {'train': [0], 'test': [0]}
- valid_images = {
+ valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]}
+ valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['PRE-event', 'POST-event'],
'test': ['PRE-event', 'POST-event'],
}
- valid_masks = ['annotations']
- chip_size = {'PRE-event': (1300, 1300), 'POST-event': (1300, 1300)}
+ valid_masks = ('annotations',)
+ chip_size: ClassVar[dict[str, tuple[int, int]]] = {
+ 'PRE-event': (1300, 1300),
+ 'POST-event': (1300, 1300),
+ }
diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py
index a087cbf68dc..b2840afb865 100644
--- a/torchgeo/datasets/ssl4eo.py
+++ b/torchgeo/datasets/ssl4eo.py
@@ -7,7 +7,7 @@
import os
import random
from collections.abc import Callable
-from typing import TypedDict
+from typing import ClassVar, TypedDict
import matplotlib.pyplot as plt
import numpy as np
@@ -93,13 +93,13 @@ class SSL4EOL(NonGeoDataset):
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
.. versionadded:: 0.5
- """ # noqa: E501
+ """
class _Metadata(TypedDict):
num_bands: int
rgb_bands: list[int]
- metadata: dict[str, _Metadata] = {
+ metadata: ClassVar[dict[str, _Metadata]] = {
'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]},
'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]},
'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]},
@@ -107,8 +107,8 @@ class _Metadata(TypedDict):
'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]},
}
- url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' # noqa: E501
- checksums = {
+ url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}'
+ checksums: ClassVar[dict[str, dict[str, str]]] = {
'tm_toa': {
'aa': '553795b8d73aa253445b1e67c5b81f11',
'ab': 'e9e0739b5171b37d16086cb89ab370e8',
@@ -357,7 +357,7 @@ class _Metadata(TypedDict):
md5: str
bands: list[str]
- metadata: dict[str, _Metadata] = {
+ metadata: ClassVar[dict[str, _Metadata]] = {
's1': {
'filename': 's1.tar.gz',
'md5': '51ee23b33eb0a2f920bda25225072f3a',
diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py
index 3f16bf33b07..fc9bb3883d3 100644
--- a/torchgeo/datasets/ssl4eo_benchmark.py
+++ b/torchgeo/datasets/ssl4eo_benchmark.py
@@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -46,16 +47,16 @@ class SSL4EOLBenchmark(NonGeoDataset):
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
.. versionadded:: 0.5
- """ # noqa: E501
+ """
- url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' # noqa: E501
+ url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz'
- valid_sensors = ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr']
- valid_products = ['cdl', 'nlcd']
- valid_splits = ['train', 'val', 'test']
+ valid_sensors = ('tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr')
+ valid_products = ('cdl', 'nlcd')
+ valid_splits = ('train', 'val', 'test')
image_root = 'ssl4eo_l_{}_benchmark'
- img_md5s = {
+ img_md5s: ClassVar[dict[str, str]] = {
'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15',
'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f',
'etm_sr': '34a24fa89a801654f8d01e054662c8cd',
@@ -63,14 +64,14 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15',
}
- mask_dir_dict = {
+ mask_dir_dict: ClassVar[dict[str, str]] = {
'tm_toa': 'ssl4eo_l_tm_{}',
'etm_toa': 'ssl4eo_l_etm_{}',
'etm_sr': 'ssl4eo_l_etm_{}',
'oli_tirs_toa': 'ssl4eo_l_oli_{}',
'oli_sr': 'ssl4eo_l_oli_{}',
}
- mask_md5s = {
+ mask_md5s: ClassVar[dict[str, dict[str, str]]] = {
'tm': {
'cdl': '3d676770ffb56c7e222a7192a652a846',
'nlcd': '261149d7614fcfdcb3be368eefa825c7',
@@ -85,7 +86,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
},
}
- year_dict = {
+ year_dict: ClassVar[dict[str, int]] = {
'tm_toa': 2011,
'etm_toa': 2019,
'etm_sr': 2019,
@@ -93,7 +94,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': 2019,
}
- rgb_indices = {
+ rgb_indices: ClassVar[dict[str, list[int]]] = {
'tm_toa': [2, 1, 0],
'etm_toa': [2, 1, 0],
'etm_sr': [2, 1, 0],
@@ -101,9 +102,12 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': [3, 2, 1],
}
- split_percentages = [0.7, 0.15, 0.15]
+ split_percentages = (0.7, 0.15, 0.15)
- cmaps = {'nlcd': NLCD.cmap, 'cdl': CDL.cmap}
+ cmaps: ClassVar[dict[str, dict[int, tuple[int, int, int, int]]]] = {
+ 'nlcd': NLCD.cmap,
+ 'cdl': CDL.cmap,
+ }
def __init__(
self,
diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py
index 604c17bc70e..eec9be57ab3 100644
--- a/torchgeo/datasets/sustainbench_crop_yield.py
+++ b/torchgeo/datasets/sustainbench_crop_yield.py
@@ -45,17 +45,17 @@ class SustainBenchCropYield(NonGeoDataset):
* https://doi.org/10.1609/aaai.v31i1.11172
.. versionadded:: 0.5
- """ # noqa: E501
+ """
- valid_countries = ['usa', 'brazil', 'argentina']
+ valid_countries = ('usa', 'brazil', 'argentina')
md5 = '362bad07b51a1264172b8376b39d1fc9'
- url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' # noqa: E501
+ url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link'
dir = 'soybeans'
- valid_splits = ['train', 'dev', 'test']
+ valid_splits = ('train', 'dev', 'test')
def __init__(
self,
diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py
index 7045ee558a8..5527a7ed133 100644
--- a/torchgeo/datasets/ucmerced.py
+++ b/torchgeo/datasets/ucmerced.py
@@ -5,7 +5,7 @@
import os
from collections.abc import Callable
-from typing import cast
+from typing import ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@@ -66,19 +66,19 @@ class UCMerced(NonGeoClassificationDataset):
* https://dl.acm.org/doi/10.1145/1869790.1869829
"""
- url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip' # noqa: E501
+ url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip'
filename = 'UCMerced_LandUse.zip'
md5 = '5b7ec56793786b6dc8a908e8854ac0e4'
base_dir = os.path.join('UCMerced_LandUse', 'Images')
- splits = ['train', 'val', 'test']
- split_urls = {
- 'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt', # noqa: E501
- 'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt', # noqa: E501
- 'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt', # noqa: E501
+ splits = ('train', 'val', 'test')
+ split_urls: ClassVar[dict[str, str]] = {
+ 'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt',
+ 'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt',
+ 'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt',
}
- split_md5s = {
+ split_md5s: ClassVar[dict[str, str]] = {
'train': 'f2fb12eb2210cfb53f93f063a35ff374',
'val': '11ecabfc52782e5ea6a9c7c0d263aca0',
'test': '046aff88472d8fc07c4678d03749e28d',
diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py
index 42443c56d2d..db13059bfa7 100644
--- a/torchgeo/datasets/usavars.py
+++ b/torchgeo/datasets/usavars.py
@@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -49,12 +50,12 @@ class USAVars(NonGeoDataset):
.. versionadded:: 0.3
"""
- data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' # noqa: E501
+ data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}'
dirname = 'uar'
md5 = '677e89fd20e5dd0fe4d29b61827c2456'
- label_urls = {
+ label_urls: ClassVar[dict[str, str]] = {
'housing': data_url.format('housing.csv'),
'income': data_url.format('income.csv'),
'roads': data_url.format('roads.csv'),
@@ -64,7 +65,7 @@ class USAVars(NonGeoDataset):
'treecover': data_url.format('treecover.csv'),
}
- split_metadata = {
+ split_metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'url': data_url.format('train_split.txt'),
'filename': 'train_split.txt',
@@ -82,7 +83,7 @@ class USAVars(NonGeoDataset):
},
}
- ALL_LABELS = ['treecover', 'elevation', 'population']
+ ALL_LABELS = ('treecover', 'elevation', 'population')
def __init__(
self,
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py
index aa8da3a8c82..73861011a7b 100644
--- a/torchgeo/datasets/utils.py
+++ b/torchgeo/datasets/utils.py
@@ -86,11 +86,11 @@ def __post_init__(self) -> None:
# https://github.com/PyCQA/pydocstyle/issues/525
@overload
- def __getitem__(self, key: int) -> float: # noqa: D105
+ def __getitem__(self, key: int) -> float:
pass
@overload
- def __getitem__(self, key: slice) -> list[float]: # noqa: D105
+ def __getitem__(self, key: slice) -> list[float]:
pass
def __getitem__(self, key: int | slice) -> float | list[float]:
@@ -289,7 +289,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess[byt
The completed process.
"""
kwargs['check'] = True
- return subprocess.run((self.name,) + args, **kwargs)
+ return subprocess.run((self.name, *args), **kwargs)
def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
@@ -547,7 +547,7 @@ def draw_semantic_segmentation_masks(
def rgb_to_mask(
- rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]]
+ rgb: np.typing.NDArray[np.uint8], colors: Sequence[tuple[int, int, int]]
) -> np.typing.NDArray[np.uint8]:
"""Converts an RGB colormap mask to a integer mask.
diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py
index 305eb950197..2c671ca27ac 100644
--- a/torchgeo/datasets/vaihingen.py
+++ b/torchgeo/datasets/vaihingen.py
@@ -5,6 +5,7 @@
import os
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -55,15 +56,15 @@ class Vaihingen2D(NonGeoDataset):
* https://doi.org/10.5194/isprsannals-I-3-293-2012
.. versionadded:: 0.2
- """ # noqa: E501
+ """
- filenames = [
+ filenames = (
'ISPRS_semantic_labeling_Vaihingen.zip',
'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip',
- ]
- md5s = ['462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277']
+ )
+ md5s = ('462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277')
image_root = 'top'
- splits = {
+ splits: ClassVar[dict[str, list[str]]] = {
'train': [
'top_mosaic_09cm_area1.tif',
'top_mosaic_09cm_area11.tif',
@@ -102,22 +103,22 @@ class Vaihingen2D(NonGeoDataset):
'top_mosaic_09cm_area29.tif',
],
}
- classes = [
+ classes = (
'Clutter/background',
'Impervious surfaces',
'Building',
'Low Vegetation',
'Tree',
'Car',
- ]
- colormap = [
+ )
+ colormap = (
(255, 0, 0),
(255, 255, 255),
(0, 0, 255),
(0, 255, 255),
(0, 255, 0),
(255, 255, 0),
- ]
+ )
def __init__(
self,
@@ -258,7 +259,7 @@ def plot(
"""
ncols = 1
image1 = draw_semantic_segmentation_masks(
- sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap
+ sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap)
)
if 'prediction' in sample:
ncols += 1
@@ -266,7 +267,7 @@ def plot(
sample['image'][:3],
sample['prediction'],
alpha=alpha,
- colors=self.colormap,
+ colors=list(self.colormap),
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index adce27494af..4840a788604 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -5,7 +5,7 @@
import os
from collections.abc import Callable
-from typing import Any
+from typing import Any, ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -158,18 +158,18 @@ class VHR10(NonGeoDataset):
``annotations.json`` file for the "positive" image set
"""
- image_meta = {
+ image_meta: ClassVar[dict[str, str]] = {
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/NWPU%20VHR-10%20dataset.zip',
'filename': 'NWPU VHR-10 dataset.zip',
'md5': '6add6751469c12dd8c8d6223064c6c4d',
}
- target_meta = {
+ target_meta: ClassVar[dict[str, str]] = {
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/annotations.json',
'filename': 'annotations.json',
'md5': '7c76ec50c17a61bb0514050d20f22c08',
}
- categories = [
+ categories = (
'background',
'airplane',
'ships',
@@ -181,7 +181,7 @@ class VHR10(NonGeoDataset):
'harbor',
'bridge',
'vehicle',
- ]
+ )
def __init__(
self,
diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py
index 136689894f1..fe51f6ade8f 100644
--- a/torchgeo/datasets/western_usa_live_fuel_moisture.py
+++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py
@@ -6,7 +6,7 @@
import glob
import json
import os
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
from typing import Any
import pandas as pd
@@ -53,7 +53,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
label_name = 'percent(t)'
- all_variable_names = [
+ all_variable_names = (
# "date",
'slope(t)',
'elevation(t)',
@@ -193,12 +193,12 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
'vh_vv(t-3)',
'lat',
'lon',
- ]
+ )
def __init__(
self,
root: Path = 'data',
- input_features: list[str] = all_variable_names,
+ input_features: Iterable[str] = all_variable_names,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
) -> None:
@@ -273,7 +273,7 @@ def _load_data(self) -> pd.DataFrame:
data_rows.append(data_dict)
df = pd.DataFrame(data_rows)
- df = df[self.input_features + [self.label_name]]
+ df = df[[*self.input_features, self.label_name]]
return df
def _verify(self) -> None:
diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index 0cb66cba9f3..a7f6a36456a 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable
+from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@@ -54,7 +55,7 @@ class XView2(NonGeoDataset):
.. versionadded:: 0.2
"""
- metadata = {
+ metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'filename': 'train_images_labels_targets.tar.gz',
'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16',
@@ -66,8 +67,8 @@ class XView2(NonGeoDataset):
'directory': 'test',
},
}
- classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed']
- colormap = ['green', 'blue', 'orange', 'red']
+ classes = ('background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed')
+ colormap = ('green', 'blue', 'orange', 'red')
def __init__(
self,
@@ -242,10 +243,16 @@ def plot(
"""
ncols = 2
image1 = draw_semantic_segmentation_masks(
- sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap
+ sample['image'][0],
+ sample['mask'][0],
+ alpha=alpha,
+ colors=list(self.colormap),
)
image2 = draw_semantic_segmentation_masks(
- sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap
+ sample['image'][1],
+ sample['mask'][1],
+ alpha=alpha,
+ colors=list(self.colormap),
)
if 'prediction' in sample: # NOTE: this assumes predictions are made for post
ncols += 1
@@ -253,7 +260,7 @@ def plot(
sample['image'][1],
sample['prediction'],
alpha=alpha,
- colors=self.colormap,
+ colors=list(self.colormap),
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py
index 394721b5237..2928dc58a70 100644
--- a/torchgeo/datasets/zuericrop.py
+++ b/torchgeo/datasets/zuericrop.py
@@ -52,15 +52,15 @@ class ZueriCrop(NonGeoDataset):
* `h5py `_ to load the dataset
"""
- urls = [
+ urls = (
'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download',
- 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', # noqa: E501
- ]
- md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b']
- filenames = ['ZueriCrop.hdf5', 'labels.csv']
+ 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv',
+ )
+ md5s = ('1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b')
+ filenames = ('ZueriCrop.hdf5', 'labels.csv')
band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12')
- rgb_bands = ['B04', 'B03', 'B02']
+ rgb_bands = ('B04', 'B03', 'B02')
def __init__(
self,
diff --git a/torchgeo/main.py b/torchgeo/main.py
index b403d4fa50c..48e84b6e8cf 100644
--- a/torchgeo/main.py
+++ b/torchgeo/main.py
@@ -8,7 +8,7 @@
from lightning.pytorch.cli import ArgsType, LightningCLI
# Allows classes to be referenced using only the class name
-import torchgeo.datamodules # noqa: F401
+import torchgeo.datamodules
import torchgeo.trainers # noqa: F401
from torchgeo.datamodules import BaseDataModule
from torchgeo.trainers import BaseTask
diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py
index 9e214d9d04e..1caf32a47a4 100644
--- a/torchgeo/models/api.py
+++ b/torchgeo/models/api.py
@@ -8,7 +8,7 @@
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
* https://pytorch.org/vision/stable/models.html
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
-""" # noqa: E501
+"""
from collections.abc import Callable
from typing import Any
diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py
index 32f0be01a61..e82fbd574bd 100644
--- a/torchgeo/models/dofa.py
+++ b/torchgeo/models/dofa.py
@@ -384,7 +384,7 @@ class DOFABase16_Weights(WeightsEnum): # type: ignore[misc]
"""
DOFA_MAE = Weights(
- url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', # noqa: E501
+ url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth',
transforms=_dofa_transforms,
meta={
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
@@ -403,7 +403,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc]
"""
DOFA_MAE = Weights(
- url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', # noqa: E501
+ url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth',
transforms=_dofa_transforms,
meta={
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py
index 82545afa339..079b1222867 100644
--- a/torchgeo/models/rcf.py
+++ b/torchgeo/models/rcf.py
@@ -140,7 +140,7 @@ def _normalize(
a numpy array of size (N, C, H, W) containing the normalized patches
.. versionadded:: 0.5
- """ # noqa: E501
+ """
n_patches = patches.shape[0]
orig_shape = patches.shape
patches = patches.reshape(patches.shape[0], -1)
diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py
index 05d95ceb8f1..c62c71dea6a 100644
--- a/torchgeo/models/resnet.py
+++ b/torchgeo/models/resnet.py
@@ -11,8 +11,8 @@
from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum
-# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
-# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = K.AugmentationSequential(
K.Resize(256),
@@ -22,7 +22,7 @@
)
# Normalization only available for RGB dataset, defined here:
-# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
+# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])
_mean = torch.tensor([0.485, 0.456, 0.406])
@@ -37,7 +37,7 @@
)
# Normalization only available for RGB dataset, defined here:
-# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501
+# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])
_gassl_transforms = K.AugmentationSequential(
@@ -47,7 +47,7 @@
data_keys=None,
)
-# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
+# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
_ssl4eo_l_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.CenterCrop((224, 224)),
@@ -70,7 +70,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
"""
LANDSAT_TM_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -83,7 +83,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_TM_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -96,7 +96,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -109,7 +109,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -122,7 +122,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -135,7 +135,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -148,7 +148,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -161,7 +161,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -174,7 +174,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -187,7 +187,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -200,7 +200,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_MOCO = Weights(
- url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@@ -213,7 +213,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_MOCO = Weights(
- url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@@ -226,7 +226,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_SECO = Weights(
- url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth',
transforms=_seco_transforms,
meta={
'dataset': 'SeCo Dataset',
@@ -249,7 +249,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
"""
FMOW_RGB_GASSL = Weights(
- url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth',
transforms=_gassl_transforms,
meta={
'dataset': 'fMoW Dataset',
@@ -262,7 +262,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_TM_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -275,7 +275,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_TM_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -288,7 +288,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -301,7 +301,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -314,7 +314,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -327,7 +327,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -340,7 +340,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -353,7 +353,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -366,7 +366,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -379,7 +379,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -392,7 +392,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL1_ALL_MOCO = Weights(
- url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@@ -405,7 +405,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_DINO = Weights(
- url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@@ -418,7 +418,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_MOCO = Weights(
- url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@@ -431,7 +431,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_MOCO = Weights(
- url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@@ -444,7 +444,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_SECO = Weights(
- url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth', # noqa: E501
+ url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth',
transforms=_seco_transforms,
meta={
'dataset': 'SeCo Dataset',
diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py
index 21df8dedd91..f29c2ffabcb 100644
--- a/torchgeo/models/swin.py
+++ b/torchgeo/models/swin.py
@@ -12,20 +12,20 @@
from torchvision.models import SwinTransformer
from torchvision.models._api import Weights, WeightsEnum
-# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501
+# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
-# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
-# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501
+# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
+# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255.
_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None
)
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
-# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
-# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). # noqa: E501
+# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
+# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1).
_std = torch.tensor(
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
-) # noqa: E501
+)
_mean = torch.zeros_like(_std)
_sentinel2_ms_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=_mean, std=_std),
@@ -33,7 +33,7 @@
data_keys=None,
)
-# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501
+# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1).
_landsat_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
@@ -56,7 +56,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
"""
NAIP_RGB_SI_SATLAS = Weights(
- url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth', # noqa: E501
+ url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'Satlas',
@@ -68,7 +68,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_SI_SATLAS = Weights(
- url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth', # noqa: E501
+ url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'Satlas',
@@ -80,7 +80,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_MS_SI_SATLAS = Weights(
- url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth', # noqa: E501
+ url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth',
transforms=_sentinel2_ms_satlas_transforms,
meta={
'dataset': 'Satlas',
@@ -93,7 +93,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL1_SI_SATLAS = Weights(
- url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth', # noqa: E501
+ url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'Satlas',
@@ -106,7 +106,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_SI_SATLAS = Weights(
- url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth', # noqa: E501
+ url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth',
transforms=_landsat_satlas_transforms,
meta={
'dataset': 'Satlas',
@@ -126,7 +126,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
'B09',
'B10',
'B11',
- ], # noqa: E501
+ ],
},
)
diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py
index a81ac13d48a..1878883f484 100644
--- a/torchgeo/models/vit.py
+++ b/torchgeo/models/vit.py
@@ -11,8 +11,8 @@
from timm.models.vision_transformer import VisionTransformer
from torchvision.models._api import Weights, WeightsEnum
-# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
-# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = K.AugmentationSequential(
K.Resize(256),
@@ -21,7 +21,7 @@
data_keys=None,
)
-# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
+# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
_ssl4eo_l_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.CenterCrop((224, 224)),
@@ -44,7 +44,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
"""
LANDSAT_TM_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -57,7 +57,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_TM_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -70,7 +70,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -83,7 +83,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -96,7 +96,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -109,7 +109,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -122,7 +122,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -135,7 +135,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -148,7 +148,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_MOCO = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -161,7 +161,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_SIMCLR = Weights(
- url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth', # noqa: E501
+ url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@@ -174,7 +174,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_DINO = Weights(
- url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth', # noqa: E501
+ url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@@ -187,7 +187,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_MOCO = Weights(
- url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth', # noqa: E501
+ url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py
index 87484e75730..b2210eb2518 100644
--- a/torchgeo/transforms/transforms.py
+++ b/torchgeo/transforms/transforms.py
@@ -53,7 +53,7 @@ def __init__(
else:
keys.append(key)
- self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] # noqa: E501
+ self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type]
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Perform augmentations and update data dict.