Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff: enable ruff-specific rules #2218

Merged
merged 4 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('..'))

import torchgeo # noqa: E402
import torchgeo

# -- Project information -----------------------------------------------------

Expand Down
8 changes: 4 additions & 4 deletions docs/tutorials/custom_raster_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,8 @@
" date_format = '%Y%m%dT%H%M%S'\n",
" is_image = True\n",
" separate_files = True\n",
" all_bands = ['B02', 'B03', 'B04', 'B08']\n",
" rgb_bands = ['B04', 'B03', 'B02']"
" all_bands = ('B02', 'B03', 'B04', 'B08')\n",
" rgb_bands = ('B04', 'B03', 'B02')"
]
},
{
Expand Down Expand Up @@ -432,8 +432,8 @@
" date_format = '%Y%m%dT%H%M%S'\n",
" is_image = True\n",
" separate_files = True\n",
" all_bands = ['B02', 'B03', 'B04', 'B08']\n",
" rgb_bands = ['B04', 'B03', 'B02']\n",
" all_bands = ('B02', 'B03', 'B04', 'B08')\n",
" rgb_bands = ('B04', 'B03', 'B02')\n",
"\n",
" def plot(self, sample):\n",
" # Find the correct band index order\n",
Expand Down
2 changes: 1 addition & 1 deletion experiments/ssl4eo/download_ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def filter_collection(

if filtered.size().getInfo() == 0:
raise ee.EEException(
f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' # noqa: E501
f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.'
)
return filtered

Expand Down
2 changes: 1 addition & 1 deletion experiments/ssl4eo/sample_ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
def get_world_cities(
download_root: str = 'world_cities', size: int = 10000
) -> pd.DataFrame:
url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' # noqa: E501
url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip'
filename = 'worldcities.csv'
download_and_extract_archive(url, download_root)
cols = ['city', 'lat', 'lng', 'population']
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ quote-style = "single"
skip-magic-trailing-comma = true

[tool.ruff.lint]
extend-select = ["ANN", "D", "I", "NPY201", "UP"]
extend-select = ["ANN", "D", "I", "NPY201", "RUF", "UP"]
ignore = ["ANN101", "ANN102", "ANN401"]

[tool.ruff.lint.per-file-ignores]
Expand Down
28 changes: 14 additions & 14 deletions tests/data/dfc2022/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,36 @@

train_set = [
{
'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', # noqa: E501
'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', # noqa: E501
'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif',
'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif',
'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif',
},
{
'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', # noqa: E501
'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', # noqa: E501
'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif',
'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif',
'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif',
},
]

unlabeled_set = [
{
'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', # noqa: E501
'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif',
'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif',
},
{
'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', # noqa: E501
'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif',
'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif',
},
]

val_set = [
{
'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', # noqa: E501
'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif',
'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif',
},
{
'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', # noqa: E501
'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif',
'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif',
},
]

Expand Down
6 changes: 3 additions & 3 deletions tests/data/seasonet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
# Compute checksums
with open(archive, 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'{season}: {repr(md5)}')
print(f'{season}: {md5!r}')

# Write meta.csv
with open('meta.csv', 'w') as f:
Expand All @@ -121,7 +121,7 @@
# Compute checksums
with open('meta.csv', 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'meta.csv: {repr(md5)}')
print(f'meta.csv: {md5!r}')

os.makedirs('splits', exist_ok=True)

Expand All @@ -138,4 +138,4 @@
# Compute checksums
with open('splits.zip', 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'splits: {repr(md5)}')
print(f'splits: {md5!r}')
2 changes: 1 addition & 1 deletion tests/datasets/test_eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,5 @@ def test_invalid_query(self, dataset: EuroCrops) -> None:
dataset[query]

def test_integrity_error(self, dataset: EuroCrops) -> None:
dataset.zenodo_files = [('AA.zip', 'invalid')]
dataset.zenodo_files = (('AA.zip', 'invalid'),)
assert not dataset._check_integrity()
4 changes: 2 additions & 2 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class CustomVectorDataset(VectorDataset):


class CustomSentinelDataset(Sentinel2):
all_bands: list[str] = []
all_bands: tuple[str, ...] = ()
separate_files = False


Expand Down Expand Up @@ -356,7 +356,7 @@ def test_no_data(self, tmp_path: Path) -> None:

def test_no_all_bands(self) -> None:
root = os.path.join('tests', 'data', 'sentinel2')
bands = ['B04', 'B03', 'B02']
bands = ('B04', 'B03', 'B02')
transforms = nn.Identity()
cache = True
msg = (
Expand Down
2 changes: 1 addition & 1 deletion tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])

@pytest.fixture
def weights(self) -> WeightsEnum:
Expand Down
12 changes: 6 additions & 6 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Expand Down Expand Up @@ -259,13 +259,13 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Expand Down
6 changes: 3 additions & 3 deletions tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Expand Down
6 changes: 3 additions & 3 deletions tests/trainers/test_iobench.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def test_trainer(self, name: str, fast_dev_run: bool) -> None:
'1',
]

main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass
2 changes: 1 addition & 1 deletion tests/trainers/test_moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])

def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'):
Expand Down
12 changes: 6 additions & 6 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Expand Down Expand Up @@ -237,13 +237,13 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Expand Down
6 changes: 3 additions & 3 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Expand Down
2 changes: 1 addition & 1 deletion tests/trainers/test_simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_trainer(
'1',
]

main(['fit'] + args)
main(['fit', *args])

def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'):
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/seco.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
seasons = kwargs.get('seasons', 1)

# Normalization only available for RGB dataset, defined here:
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
if bands == SeasonalContrastS2.rgb_bands:
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/so2sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""So2Sat datamodule."""

from typing import Any
from typing import Any, ClassVar

import torch
from torch import Generator, Tensor
Expand All @@ -21,7 +21,7 @@ class So2SatDataModule(NonGeoDataModule):
"train" set and use the "test" set as the test set.
"""

means_per_version: dict[str, Tensor] = {
means_per_version: ClassVar[dict[str, Tensor]] = {
'2': torch.tensor(
[
-0.00003591224260,
Expand Down Expand Up @@ -91,7 +91,7 @@ class So2SatDataModule(NonGeoDataModule):
}
means_per_version['3_culture_10'] = means_per_version['2']

stds_per_version: dict[str, Tensor] = {
stds_per_version: ClassVar[dict[str, Tensor]] = {
'2': torch.tensor(
[
0.17555201,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SSL4EOS12DataModule(NonGeoDataModule):
.. versionadded:: 0.5
"""

# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
mean = torch.tensor(0)
std = torch.tensor(10000)

Expand Down
16 changes: 8 additions & 8 deletions torchgeo/datasets/advance.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ class ADVANCE(NonGeoDataset):
* `scipy <https://pypi.org/project/scipy/>`_ to load the audio files to tensors
"""

urls = [
urls = (
'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1',
'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1',
]
filenames = ['ADVANCE_vision.zip', 'ADVANCE_sound.zip']
md5s = ['a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31']
directories = ['vision', 'sound']
classes = [
)
filenames = ('ADVANCE_vision.zip', 'ADVANCE_sound.zip')
md5s = ('a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31')
directories = ('vision', 'sound')
classes: tuple[str, ...] = (
'airport',
'beach',
'bridge',
Expand All @@ -84,7 +84,7 @@ class ADVANCE(NonGeoDataset):
'sparse shrub land',
'sports land',
'train station',
]
)

def __init__(
self,
Expand Down Expand Up @@ -119,7 +119,7 @@ def __init__(
raise DatasetNotFoundError(self)

self.files = self._load_files(self.root)
self.classes = sorted({f['cls'] for f in self.files})
self.classes = tuple(sorted({f['cls'] for f in self.files}))
self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)}

def __getitem__(self, index: int) -> dict[str, Tensor]:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/agb_live_woody_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):

is_image = False

url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' # noqa: E501
url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326'

base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson'

Expand Down
Loading