diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 0ee2e3420fb..89ab4681f40 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -15,8 +15,6 @@ jobs: latest: name: latest runs-on: ${{ matrix.os }} - env: - MPLBACKEND: Agg strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] @@ -55,8 +53,6 @@ jobs: minimum: name: minimum runs-on: ubuntu-latest - env: - MPLBACKEND: Agg steps: - name: Clone repo uses: actions/checkout@v4.1.7 @@ -90,8 +86,6 @@ jobs: datasets: name: datasets runs-on: ubuntu-latest - env: - MPLBACKEND: Agg steps: - name: Clone repo uses: actions/checkout@v4.1.7 diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 89f59a6a2de..72c55139c05 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -282,6 +282,11 @@ FAIR1M .. autoclass:: FAIR1M +Fields Of The World +^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: FieldsOfTheWorld + FireRisk ^^^^^^^^ diff --git a/docs/api/datasets/geo_datasets.csv b/docs/api/datasets/geo_datasets.csv index 0f7870efe06..4bb5788609e 100644 --- a/docs/api/datasets/geo_datasets.csv +++ b/docs/api/datasets/geo_datasets.csv @@ -20,7 +20,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30" `LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5 `Landsat`_,Imagery,Landsat,"public domain","8,900x8,900",30 -`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",1 +`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",0.3--2 `NCCM`_,Masks,Sentinel-2,"CC-BY-4.0",-,10 `NLCD`_,Masks,Landsat,"public domain",-,30 `Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,- diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index 03526ed5c5f..abdd41cc1f8 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -15,6 +15,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR `EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI `FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB +`Fields Of The World`_,"S,I",Sentinel-2,"Various","70,795","2,3",256x256,10,MSI `FireRisk`_,C,NAIP Aerial,"CC-BY-NC-4.0","91,872",7,"320x320",1,RGB `Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB `GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM" diff --git a/pyproject.toml b/pyproject.toml index 83062ce28fe..687843521e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,8 @@ datasets = [ "laspy>=2", # opencv-python 4.5.4+ required for Python 3.10 wheels "opencv-python>=4.5.4", + # pandas 2+ required for parquet extra + "pandas[parquet]>=2", # pycocotools 2.0.7+ required for wheels "pycocotools>=2.0.7", # pyvista 0.34.2+ required to avoid ImportError in CI diff --git a/requirements/datasets.txt b/requirements/datasets.txt index 5ee75b87ccd..b05a3c86894 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -1,7 +1,8 @@ # datasets -h5py==3.11.0 +h5py==3.12.1 laspy==2.5.4 opencv-python==4.10.0.84 +pandas[parquet]==2.2.3 pycocotools==2.0.8 pyvista==0.44.1 scikit-image==0.24.0 diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 24e15ba962a..65a1439f80c 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -28,6 +28,7 @@ h5py==3.6.0 laspy==2.0.0 
opencv-python==4.5.4.58 pycocotools==2.0.7 +pyarrow==15.0.0 # Remove when we upgrade min version of pandas to `pandas[parquet]>=2` pyvista==0.34.2 scikit-image==0.19.0 scipy==1.7.2 diff --git a/requirements/style.txt b/requirements/style.txt index 648a2033db5..41430302bc4 100644 --- a/requirements/style.txt +++ b/requirements/style.txt @@ -1,3 +1,3 @@ # style mypy==1.11.2 -ruff==0.6.7 +ruff==0.6.8 diff --git a/tests/conftest.py b/tests/conftest.py index 1f5c09a8bb0..d55a972ced1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any +import matplotlib import pytest import torch import torchvision @@ -19,6 +20,11 @@ def load_state_dict_from_url(monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) +@pytest.fixture(autouse=True, scope='session') +def matplotlib_backend() -> None: + matplotlib.use('agg') + + @pytest.fixture(autouse=True) def torch_hub(tmp_path: Path) -> None: torch.hub.set_dir(tmp_path) # type: ignore[no-untyped-call] diff --git a/tests/data/ftw/austria.zip b/tests/data/ftw/austria.zip new file mode 100644 index 00000000000..e8b01db1b11 Binary files /dev/null and b/tests/data/ftw/austria.zip differ diff --git a/tests/data/ftw/data.py b/tests/data/ftw/data.py new file mode 100755 index 00000000000..8ffff19d6ad --- /dev/null +++ b/tests/data/ftw/data.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil +import zipfile + +import numpy as np +import pandas as pd +import rasterio +from affine import Affine + +np.random.seed(0) + +country = 'austria' +SIZE = 32 +num_samples = {'train': 2, 'val': 2, 'test': 2} +BASE_PROFILE = { + 'driver': 'GTiff', + 'dtype': 'uint16', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'count': 4, + 'crs': 'EPSG:4326', + 'transform': Affine(5.4e-05, 0.0, 0, 0.0, 5.4e-05, 0), + 'blockxsize': SIZE, + 'blockysize': SIZE, + 'tiled': True, + 'interleave': 'pixel', +} + + +def create_image(fn: str) -> None: + os.makedirs(os.path.dirname(fn), exist_ok=True) + + profile = BASE_PROFILE.copy() + + data = np.random.randint(0, 20000, size=(4, SIZE, SIZE), dtype=np.uint16) + with rasterio.open(fn, 'w', **profile) as dst: + dst.write(data) + + +def create_mask(fn: str, min_val: int, max_val: int) -> None: + os.makedirs(os.path.dirname(fn), exist_ok=True) + + profile = BASE_PROFILE.copy() + profile['dtype'] = 'uint8' + profile['nodata'] = 0 + profile['count'] = 1 + + data = np.random.randint(min_val, max_val, size=(1, SIZE, SIZE), dtype=np.uint8) + with rasterio.open(fn, 'w', **profile) as dst: + dst.write(data) + + +if __name__ == '__main__': + i = 0 + cols = {'aoi_id': [], 'split': []} + for split, n in num_samples.items(): + for j in range(n): + aoi = f'g_{i}' + cols['aoi_id'].append(aoi) + cols['split'].append(split) + + create_image(os.path.join(country, 's2_images', 'window_a', f'{aoi}.tif')) + create_image(os.path.join(country, 's2_images', 'window_b', f'{aoi}.tif')) + + create_mask( + os.path.join(country, 'label_masks', 'semantic_2class', f'{aoi}.tif'), + 0, + 1, + ) + create_mask( + os.path.join(country, 'label_masks', 'semantic_3class', f'{aoi}.tif'), + 0, + 2, + ) + create_mask( + os.path.join(country, 'label_masks', 'instance', f'{aoi}.tif'), 0, 100 + ) + + i += 1 + + # Create an extra train file to test for missing other files + aoi = f'g_{i}' + cols['aoi_id'].append(aoi) + 
cols['split'].append(split) + create_image(os.path.join(country, 's2_images', 'window_a', f'{aoi}.tif')) + + # Write parquet index + df = pd.DataFrame(cols) + df.to_parquet(os.path.join(country, f'chips_{country}.parquet')) + + # Archive to zip + with zipfile.ZipFile(f'{country}.zip', 'w') as zipf: + for root, _, files in os.walk(country): + for file in files: + output_fn = os.path.join(root, file) + zipf.write(output_fn, os.path.relpath(output_fn, country)) + + shutil.rmtree(country) + + # Compute checksums + with open(f'{country}.zip', 'rb') as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f'{md5}') diff --git a/tests/datasets/test_ftw.py b/tests/datasets/test_ftw.py new file mode 100644 index 00000000000..1a3d7130795 --- /dev/null +++ b/tests/datasets/test_ftw.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from itertools import product +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from torch.utils.data import ConcatDataset + +from torchgeo.datasets import DatasetNotFoundError, FieldsOfTheWorld + +pytest.importorskip('pyarrow') + + +class TestFieldsOfTheWorld: + @pytest.fixture( + params=product(['train', 'val', 'test'], ['2-class', '3-class', 'instance']) + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> FieldsOfTheWorld: + split, task = request.param + + monkeypatch.setattr(FieldsOfTheWorld, 'valid_countries', ['austria']) + monkeypatch.setattr( + FieldsOfTheWorld, + 'country_to_md5', + {'austria': '1cf9593c9bdceeaba21bbcb24d35816c'}, + ) + base_url = os.path.join('tests', 'data', 'ftw') + '/' + monkeypatch.setattr(FieldsOfTheWorld, 'base_url', base_url) + root = tmp_path + transforms = nn.Identity() + return FieldsOfTheWorld( + root, + split, + task, + countries='austria', + transforms=transforms, + download=True, + checksum=True, + ) + + def test_getitem(self, dataset: FieldsOfTheWorld) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + + def test_len(self, dataset: FieldsOfTheWorld) -> None: + assert len(dataset) == 2 + + def test_add(self, dataset: FieldsOfTheWorld) -> None: + ds = dataset + dataset + assert isinstance(ds, ConcatDataset) + assert len(ds) == 4 + + def test_already_extracted(self, dataset: FieldsOfTheWorld) -> None: + FieldsOfTheWorld(root=dataset.root, download=True) + + def test_already_downloaded(self, tmp_path: Path) -> None: + url = os.path.join('tests', 'data', 'ftw', 'austria.zip') + root = tmp_path + shutil.copy(url, root) + FieldsOfTheWorld(root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + FieldsOfTheWorld(tmp_path) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + FieldsOfTheWorld(split='foo') + + def test_plot(self, dataset: FieldsOfTheWorld) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle='Test') + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x['prediction'] = x['mask'].clone() + dataset.plot(x) + plt.close() diff --git a/tests/models/test_api.py b/tests/models/test_api.py index 6c0bc6790f1..c5a56d1808a 100644 --- a/tests/models/test_api.py +++ b/tests/models/test_api.py @@ -80,3 +80,8 @@ def 
test_get_weight(enum: WeightsEnum) -> None: def test_list_models() -> None: models = [builder.__name__ for builder in builders] assert set(models) == set(list_models()) + + +def test_invalid_model() -> None: + with pytest.raises(ValueError, match='bad_model is not a valid WeightsEnum'): + get_weight('bad_model') diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index ed174b71e90..5c4fea89700 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -46,6 +46,7 @@ from .fair1m import FAIR1M from .fire_risk import FireRisk from .forestdamage import ForestDamage +from .ftw import FieldsOfTheWorld from .gbif import GBIF from .geo import ( GeoDataset, @@ -217,6 +218,7 @@ 'EuroSATSpatial', 'EuroSAT100', 'FAIR1M', + 'FieldsOfTheWorld', 'FireRisk', 'ForestDamage', 'GeoNRW', diff --git a/torchgeo/datasets/ftw.py b/torchgeo/datasets/ftw.py new file mode 100644 index 00000000000..7d4d92273d8 --- /dev/null +++ b/torchgeo/datasets/ftw.py @@ -0,0 +1,362 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Fields Of The World dataset.""" + +import os +from collections.abc import Callable, Sequence +from typing import ClassVar + +import einops +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, array_to_tensor, download_and_extract_archive, extract_archive + + +class FieldsOfTheWorld(NonGeoDataset): + """Fields Of The World dataset. + + The `Fields Of The World <https://fieldsofthe.world/>`__ + dataset is a semantic and instance segmentation dataset for delineating field + boundaries. + + Dataset features: + + * 70462 patches across 24 countries + * Each country has a train, val, and test split + * Semantic segmentation masks with and without the field boundary class + * Instance segmentation masks + + Dataset format: + + * images are four-channel GeoTIFFs with dimension 256x256 + * segmentation masks (both two and three class) are single-channel GeoTIFFs + * instance masks are single-channel GeoTIFFs + + Dataset classes: + + 1. background + 2. field + 3. field-boundary (three-class only) + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.48550/arXiv.2409.16252 + + .. 
versionadded:: 0.7 + """ + + splits = ('train', 'val', 'test') + targets = ('2-class', '3-class', 'instance') + + valid_countries = ( + 'austria', + 'belgium', + 'brazil', + 'cambodia', + 'corsica', + 'croatia', + 'denmark', + 'estonia', + 'finland', + 'france', + 'germany', + 'india', + 'kenya', + 'latvia', + 'lithuania', + 'luxembourg', + 'netherlands', + 'portugal', + 'rwanda', + 'slovakia', + 'slovenia', + 'south_africa', + 'spain', + 'sweden', + 'vietnam', + ) + + base_url = 'https://data.source.coop/kerner-lab/fields-of-the-world-archive/' + + country_to_md5: ClassVar[dict[str, str]] = { + 'austria': '35604e3e3e78b4469e443bc756e19d26', + 'belgium': '111a9048e15391c947bc778e576e99b4', + 'brazil': '2ba96f9f01f37ead1435406c3f2b7c63', + 'cambodia': '581e9b8dae9713e4d03459bcec3c0bd0', + 'corsica': '0b38846063a98a31747fdeaf1ba03980', + 'croatia': 'dc5d33e19ae9e587c97f8f4c9852c87e', + 'denmark': 'ec817210b06351668cacdbd1a8fb9471', + 'estonia': 'b9c89e559e3c7d53a724e7f32ccf88ea', + 'finland': '23f853d6cbaea5a3596d1d38cc27fd65', + 'france': 'f05314f148642ff72d8bea903c01802d', + 'germany': 'd57a7ed203b9cf89c709aab29d687cee', + 'india': '361a688507e2e5cc7ca7138be01a5b80', + 'kenya': '80ca0335b25440379f99b7011dfbdfa2', + 'latvia': '6eeaaa57cdf18f25497f84e854a86d42', + 'lithuania': '0a2f4ab3309633e2de121d936e0763ba', + 'luxembourg': '5a8357eae364cca836b87827b3c6a3d3', + 'netherlands': '3afc61d184aab5c4fd6beaecf2b6c0a9', + 'portugal': '10485b747e1d8c082d33c73d032a7e05', + 'rwanda': '087ce56bbf06b32571ef27ff67bac43b', + 'slovakia': 'f66a0294491086d4c49dc4a804446e50', + 'slovenia': '6fa3ae3920bcc2c890a0d74435d9d29b', + 'south_africa': 'b7f1412d69922e8551cf91081401ec8d', + 'spain': '908bbf29597077c2c6954c439fe8265f', + 'sweden': '4b07726c421981bb2019e8900023393e', + 'vietnam': '32e1cacebcb2da656d40ab8522eb6737', + } + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + target: str = '2-class', + countries: str | Sequence[str] = ['austria'], + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Fields Of The World dataset instance. + + Args: + root: root directory where dataset can be found + split: one of "train", "val", or "test" + target: one of "2-class", "3-class", or "instance" specifying which kind of + target mask to load + countries: which set of countries to load data from + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If any arguments are invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert split in self.splits + assert target in self.targets + if isinstance(countries, str): + countries = [countries] + assert set(countries) <= set(self.valid_countries) + + self.root = root + self.split = split + self.target = target + self.countries = countries + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + self.files = self._load_files() + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. 
+ + Args: + index: index to return + + Returns: + image and mask at that index, with image of dimension 8x256x256 + (windows A and B stacked along the channel axis) and mask of + dimension 256x256 + """ + win_a_fn = self.files[index]['win_a'] + win_b_fn = self.files[index]['win_b'] + mask_fn = self.files[index]['mask'] + + win_a = self._load_image(win_a_fn) + win_b = self._load_image(win_b_fn) + mask = self._load_target(mask_fn) + + image = torch.cat((win_a, win_b), dim=0) + sample = {'image': image, 'mask': mask} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of datapoints in the dataset. + + Returns: + length of dataset + """ + return len(self.files) + + def _load_files(self) -> list[dict[str, str]]: + """Return the paths of the files in the dataset. + + Returns: + a list of dictionaries, each with "win_a", "win_b", and "mask" keys + containing the corresponding file paths + """ + files = [] + for country in self.countries: + df = pd.read_parquet( + os.path.join(self.root, country, f'chips_{country}.parquet') + ) + aois = df[df['split'] == self.split]['aoi_id'].values + + for aoi in aois: + if self.target == 'instance': + subdir = 'instance' + elif self.target == '2-class': + subdir = 'semantic_2class' + elif self.target == '3-class': + subdir = 'semantic_3class' + + win_a_fn = os.path.join( + self.root, country, 's2_images', 'window_a', f'{aoi}.tif' + ) + win_b_fn = os.path.join( + self.root, country, 's2_images', 'window_b', f'{aoi}.tif' + ) + + # there are 333 AOIs that are missing imagery across the dataset + if not (os.path.exists(win_a_fn) and os.path.exists(win_b_fn)): + continue + + sample = { + 'win_a': win_a_fn, + 'win_b': win_b_fn, + 'mask': os.path.join( + self.root, country, 'label_masks', subdir, f'{aoi}.tif' + ), + } + files.append(sample) + + return files + + def _load_image(self, path: Path) -> Tensor: + """Load a single image. + + Args: + path: path to the image + + Returns: + the loaded image + """ + filename = os.path.join(path) + with rasterio.open(filename) as f: + array: np.typing.NDArray[np.int_] = f.read() + tensor = array_to_tensor(array).float() + return tensor + + def _load_target(self, path: Path) -> Tensor: + """Load a single mask corresponding to an image. + + Args: + path: path to the mask + + Returns: + the mask of the image + """ + filename = os.path.join(path) + with rasterio.open(filename) as f: + array: np.typing.NDArray[np.int_] = f.read(1) + tensor = torch.from_numpy(array).long() + return tensor + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + for country in self.countries: + if self._verify_data(country): + continue + + filename = f'{country}.zip' + pathname = os.path.join(self.root, filename) + if os.path.exists(pathname): + extract_archive(pathname, os.path.join(self.root, country)) + continue + + if not self.download: + raise DatasetNotFoundError(self) + + download_and_extract_archive( + self.base_url + filename, + os.path.join(self.root, country), + filename=filename, + md5=self.country_to_md5[country] if self.checksum else None, + ) + + def _verify_data(self, country: str) -> bool: + """Verify that data for a country is extracted. 
+ + Args: + country: the country to check + + Returns: + True if the dataset directories and split files are found, else False + """ + for entry in ['label_masks', 's2_images', f'chips_{country}.parquet']: + if not os.path.exists(os.path.join(self.root, country, entry)): + return False + + return True + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15, 5)) + + win_a = einops.rearrange(sample['image'][0:3], 'c h w -> h w c') + win_b = einops.rearrange(sample['image'][4:7], 'c h w -> h w c') + mask = sample['mask'] + + win_a = torch.clip(win_a / 3000, 0, 1) + win_b = torch.clip(win_b / 3000, 0, 1) + + axs[0].imshow(win_a) + axs[0].set_title('Window A') + axs[1].imshow(win_b) + axs[1].set_title('Window B') + if self.target == 'instance': + unique_vals = sorted(np.unique(mask)) + for i, val in enumerate(unique_vals): + mask[mask == val] = i + bg_mask = mask == 0 + mask = (mask % 9) + 1 + mask[bg_mask] = 0 + axs[2].imshow(mask, vmin=0, vmax=10, cmap='tab10', interpolation='none') + axs[2].set_title('Instance mask') + elif self.target == '2-class': + axs[2].imshow(mask, vmin=0, vmax=2, cmap='gray', interpolation='none') + axs[2].set_title('2-class mask') + elif self.target == '3-class': + axs[2].imshow(mask, vmin=0, vmax=2, cmap='gray', interpolation='none') + axs[2].set_title('3-class mask') + for ax in axs: + ax.axis('off') + + if not show_titles: + for ax in axs: + ax.set_title('') + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index 6e06db82a71..b5b058726b2 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -46,7 +46,7 @@ 'vit_small_patch16_224': vit_small_patch16_224, } -_model_weights = { +_model_weights: dict[str | Callable[..., nn.Module], WeightsEnum] = { dofa_base_patch16_224: DOFABase16_Weights, dofa_large_patch16_224: DOFALarge16_Weights, resnet18: ResNet18_Weights, @@ -109,8 +109,17 @@ def get_weight(name: str) -> WeightsEnum: Returns: The requested weight enum. + + Raises: + ValueError: If *name* is not a valid WeightsEnum. """ - return eval(name) + for weight_name, weight_enum in _model_weights.items(): + if isinstance(weight_name, str): + for sub_weight_enum in weight_enum: + if name == str(sub_weight_enum): + return sub_weight_enum + + raise ValueError(f'{name} is not a valid WeightsEnum') def list_models() -> list[str]: diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py index ecc6bc8c767..0bee76aeed1 100644 --- a/torchgeo/trainers/base.py +++ b/torchgeo/trainers/base.py @@ -19,6 +19,9 @@ class BaseTask(LightningModule, ABC): .. versionadded:: 0.5 """ + #: Parameters to ignore when saving hyperparameters. + ignore: Sequence[str] | str | None = 'weights' + #: Model to train. model: Any @@ -28,14 +31,14 @@ class BaseTask(LightningModule, ABC): #: Whether the goal is to minimize or maximize the performance metric to monitor. mode = 'min' - def __init__(self, ignore: Sequence[str] | str | None = None) -> None: + def __init__(self) -> None: """Initialize a new BaseTask instance. 
""" super().__init__() - self.save_hyperparameters(ignore=ignore) + self.save_hyperparameters(ignore=self.ignore) self.configure_models() self.configure_losses() self.configure_metrics() diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 6cd4f50f8c5..a9e91ce653d 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -324,7 +324,7 @@ def __init__( renamed to *model*, *lr*, and *patience*. """ self.weights = weights - super().__init__(ignore='weights') + super().__init__() def configure_models(self) -> None: """Initialize the model.""" diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index cc293099519..2e2766419a5 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -73,7 +73,7 @@ class and used with 'ce' loss. *lr* and *patience*. """ self.weights = weights - super().__init__(ignore='weights') + super().__init__() def configure_models(self) -> None: """Initialize the model.""" diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index d13d84dcf15..3d970abdae0 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -53,6 +53,7 @@ class ObjectDetectionTask(BaseTask): .. versionadded:: 0.4 """ + ignore = None monitor = 'val_map' mode = 'max' diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 88bb0ffbcf4..d41f2ea581c 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -136,6 +136,7 @@ class MoCoTask(BaseTask): .. versionadded:: 0.5 """ + ignore = ('weights', 'augmentation1', 'augmentation2') monitor = 'train_loss' def __init__( @@ -219,7 +220,7 @@ def __init__( warnings.warn('MoCo v3 does not use a memory bank') self.weights = weights - super().__init__(ignore=['weights', 'augmentation1', 'augmentation2']) + super().__init__() grayscale_weights = grayscale_weights or torch.ones(in_channels) aug1, aug2 = moco_augmentations(version, size, grayscale_weights) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 86c3423c656..0381316050b 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -77,7 +77,7 @@ def __init__( *lr* and *patience*. """ self.weights = weights - super().__init__(ignore='weights') + super().__init__() def configure_models(self) -> None: """Initialize the model.""" diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 9eccda2f0f8..f8e519fa493 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -88,7 +88,7 @@ class and used with 'ce' loss. The *ignore_index* parameter now works for jaccard loss. """ self.weights = weights - super().__init__(ignore='weights') + super().__init__() def configure_models(self) -> None: """Initialize the model. diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index e6764f141b7..1cb05315f60 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -68,6 +68,7 @@ class SimCLRTask(BaseTask): .. versionadded:: 0.5 """ + ignore = ('weights', 'augmentations') monitor = 'train_loss' def __init__( @@ -140,7 +141,7 @@ def __init__( warnings.warn('SimCLR v2 uses a memory bank') self.weights = weights - super().__init__(ignore=['weights', 'augmentations']) + super().__init__() grayscale_weights = grayscale_weights or torch.ones(in_channels) self.augmentations = augmentations or simclr_augmentations(