diff --git a/tests/models/test_resnet.py b/tests/models/test_resnet.py index e97cf059161..f17b5f19b23 100644 --- a/tests/models/test_resnet.py +++ b/tests/models/test_resnet.py @@ -45,6 +45,11 @@ def test_resnet(self) -> None: def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: resnet18(weights=mocked_weights) + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = mocked_weights.meta["in_chans"] + sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)} + mocked_weights.transforms(sample) + @pytest.mark.slow def test_resnet_download(self, weights: WeightsEnum) -> None: resnet18(weights=weights) @@ -75,6 +80,11 @@ def test_resnet(self) -> None: def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: resnet50(weights=mocked_weights) + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = mocked_weights.meta["in_chans"] + sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)} + mocked_weights.transforms(sample) + @pytest.mark.slow def test_resnet_download(self, weights: WeightsEnum) -> None: resnet50(weights=weights) diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py index 058ecf35f89..88124584488 100644 --- a/tests/models/test_vit.py +++ b/tests/models/test_vit.py @@ -47,6 +47,11 @@ def test_vit(self) -> None: def test_vit_weights(self, mocked_weights: WeightsEnum) -> None: vit_small_patch16_224(weights=mocked_weights) + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = mocked_weights.meta["in_chans"] + sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)} + mocked_weights.transforms(sample) + @pytest.mark.slow def test_vit_download(self, weights: WeightsEnum) -> None: vit_small_patch16_224(weights=weights) diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py new file mode 100644 index 00000000000..529999fbfff --- /dev/null +++ b/torchgeo/datamodules/seco.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Seasonal Contrast datamodule.""" + +from typing import Any + +import kornia.augmentation as K +import torch +from einops import repeat + +from ..datasets import SeasonalContrastS2 +from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule + + +class SeasonalContrastS2DataModule(NonGeoDataModule): + """LightningDataModule implementation for the Seasonal Contrast dataset. + + .. versionadded:: 0.5 + """ + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a new SeasonalContrastS2DataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.SeasonalContrastS2`. + """ + super().__init__(SeasonalContrastS2, batch_size, num_workers, **kwargs) + + bands = kwargs.get("bands", SeasonalContrastS2.rgb_bands) + seasons = kwargs.get("seasons", 1) + + # Normalization only available for RGB dataset, defined here: + # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501 + if bands == SeasonalContrastS2.rgb_bands: + _min = torch.tensor([3, 2, 0]) + _max = torch.tensor([88, 103, 129]) + _mean = torch.tensor([0.485, 0.456, 0.406]) + _std = torch.tensor([0.229, 0.224, 0.225]) + + _min = repeat(_min, "c -> (t c)", t=seasons) + _max = repeat(_max, "c -> (t c)", t=seasons) + _mean = repeat(_mean, "c -> (t c)", t=seasons) + _std = repeat(_std, "c -> (t c)", t=seasons) + + self.aug = AugmentationSequential( + K.Normalize(mean=_min, std=_max - _min), + K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), + K.Normalize(mean=_mean, std=_std), + data_keys=["image"], + ) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + self.dataset = SeasonalContrastS2(**self.kwargs) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 0dfd5d848c2..bd785e85f85 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -16,53 +16,28 @@ __all__ = ["ResNet50_Weights", "ResNet18_Weights"] -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 # Normalization either by 10K or channel-wise with band statistics _zhu_xlab_transforms = AugmentationSequential( K.Resize(256), K.CenterCrop(224), - K.Normalize(mean=0, std=10000), + K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)), data_keys=["image"], ) -# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/bigearthnet_dataset.py#L13 # noqa: E501 +# Normalization only available for RGB dataset, defined here: +# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501 +_min = torch.tensor([3, 2, 0]) +_max = torch.tensor([88, 103, 129]) +_mean = torch.tensor([0.485, 0.456, 0.406]) +_std = torch.tensor([0.229, 0.224, 0.225]) _seco_transforms = AugmentationSequential( - K.Resize(128), - K.Normalize( - mean=torch.Tensor( - [ - 340.76769064, - 429.9430203, - 614.21682446, - 590.23569706, - 950.68368468, - 1792.46290469, - 2075.46795189, - 2218.94553375, - 2266.46036911, - 2246.0605464, - 1594.42694882, - 1009.32729131, - ] - ), - std=torch.Tensor( - [ - 554.81258967, - 572.41639287, - 582.87945694, - 675.88746967, - 729.89827633, - 1096.01480586, - 1273.45393088, - 1365.45589904, - 1356.13789355, - 1302.3292881, - 1079.19066363, - 818.86747235, - ] - ), - ), + K.Resize(256), + K.CenterCrop(224), + K.Normalize(mean=_min, std=_max - _min), + K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), + K.Normalize(mean=_mean, std=_std), data_keys=["image"], ) diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py index 52fca28181e..7080257852c 100644 --- a/torchgeo/models/vit.py +++ b/torchgeo/models/vit.py @@ -7,6 +7,7 @@ import kornia.augmentation as K import timm +import torch from timm.models.vision_transformer import VisionTransformer from torchvision.models._api import Weights, WeightsEnum @@ -20,7 +21,7 @@ _zhu_xlab_transforms = AugmentationSequential( K.Resize(256), K.CenterCrop(224), - K.Normalize(mean=0, std=10000), + K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)), data_keys=["image"], )