diff --git a/docs/api/landsat_pretrained_weights.csv b/docs/api/landsat_pretrained_weights.csv index 8a360355b25..b62fd9f8be4 100644 --- a/docs/api/landsat_pretrained_weights.csv +++ b/docs/api/landsat_pretrained_weights.csv @@ -29,3 +29,4 @@ ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,"CC0-1.0",63.65,46.68,60.01,43.17 ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,"CC0-1.0",66.81,50.16,64.17,47.24 ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link `__,`link `__,"CC0-1.0",65.04,48.20,62.61,45.46 +Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,'link `__,`link `__,"ODC-BY",,,, diff --git a/docs/api/naip_pretrained_weights.csv b/docs/api/naip_pretrained_weights.csv index fa4acf57f1a..e8e8ef14b8b 100644 --- a/docs/api/naip_pretrained_weights.csv +++ b/docs/api/naip_pretrained_weights.csv @@ -1,2 +1,2 @@ Weight,Channels,Source,Citation,License -Swin_V2_B_Weights.NAIP_RGB_SATLAS,3,`link `__,`link `__,"Apache-2.0" \ No newline at end of file +Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link `__,`link `__,"ODC-BY" diff --git a/docs/api/sentinel1_pretrained_weights.csv b/docs/api/sentinel1_pretrained_weights.csv index 3b34c589119..05d623ccb10 100644 --- a/docs/api/sentinel1_pretrained_weights.csv +++ b/docs/api/sentinel1_pretrained_weights.csv @@ -1,2 +1,3 @@ Weight,Channels,Source,Citation,License ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link `__,`link `__,"CC-BY-4.0" +Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link `__,`link `__,"ODC-BY" diff --git a/docs/api/sentinel2_pretrained_weights.csv b/docs/api/sentinel2_pretrained_weights.csv index aca56ebe8cb..48869d42a30 100644 --- a/docs/api/sentinel2_pretrained_weights.csv +++ b/docs/api/sentinel2_pretrained_weights.csv @@ -8,4 +8,5 @@ ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,"Apache-2.0",87.81,,, ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,"CC-BY-4.0",90.5,99.0,62.2, ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",89.9,98.6,61.6, -Swin_V2_B_Weights.SENTINEL2_RGB_SATLAS,3,`link `__,`link `__,"Apache-2.0",,,, \ No newline at end of file +Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS,3,`link `__,`link `__,"ODC-BY",,,, +Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS,9,`link `__,`link `__,"ODC-BY",,,, diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index a1e409b7dcb..91b944c8f7a 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -8,6 +8,7 @@ import kornia.augmentation as K import torch import torchvision +from kornia.contrib import Lambda from torchvision.models import SwinTransformer from torchvision.models._api import Weights, WeightsEnum @@ -15,12 +16,31 @@ __all__ = ["Swin_V2_B_Weights"] - # https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501 -# All Satlas imagery is uint8 and normalized to the range (0, 1) by dividing by 255 +# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). +# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501 +# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501 _satlas_transforms = AugmentationSequential( - K.CenterCrop(256), - K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), + K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=["image"] +) + +# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). +# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501 +# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). # noqa: E501 +_std = torch.tensor( + [255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0] +) # noqa: E501 +_mean = torch.zeros_like(_std) +_sentinel2_ms_satlas_transforms = AugmentationSequential( + K.Normalize(mean=_mean, std=_std), + Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), + data_keys=["image"], +) + +# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501 +_landsat_satlas_transforms = AugmentationSequential( + K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)), + Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), data_keys=["image"], ) @@ -39,8 +59,8 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] .. versionadded:: 0.6 """ - NAIP_RGB_SATLAS = Weights( - url="https://huggingface.co/torchgeo/swin_v2_b_naip_rgb_satlas/resolve/main/swin_v2_b_naip_rgb_satlas-685f45bd.pth", # noqa: E501 + NAIP_RGB_SI_SATLAS = Weights( + url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/aerial_swinb_si.pth", # noqa: E501 transforms=_satlas_transforms, meta={ "dataset": "Satlas", @@ -51,8 +71,8 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] }, ) - SENTINEL2_RGB_SATLAS = Weights( - url="https://huggingface.co/torchgeo/swin_v2_b_sentinel2_rgb_satlas/resolve/main/swin_v2_b_sentinel2_rgb_satlas-51471041.pth", # noqa: E501 + SENTINEL2_RGB_SI_SATLAS = Weights( + url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_rgb.pth", # noqa: E501 transforms=_satlas_transforms, meta={ "dataset": "Satlas", @@ -63,6 +83,57 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] }, ) + SENTINEL2_MS_SI_SATLAS = Weights( + url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_ms.pth", # noqa: E501 + transforms=_sentinel2_ms_satlas_transforms, + meta={ + "dataset": "Satlas", + "in_chans": 9, + "model": "swin_v2_b", + "publication": "https://arxiv.org/abs/2211.15660", + "repo": "https://github.com/allenai/satlas", + "bands": ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B11", "B12"], + }, + ) + + SENTINEL1_SI_SATLAS = Weights( + url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel1_swinb_si.pth", # noqa: E501 + transforms=_satlas_transforms, + meta={ + "dataset": "Satlas", + "in_chans": 2, + "model": "swin_v2_b", + "publication": "https://arxiv.org/abs/2211.15660", + "repo": "https://github.com/allenai/satlas", + "bands": ["VH", "VV"], + }, + ) + + LANDSAT_SI_SATLAS = Weights( + url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/landsat_swinb_si.pth", # noqa: E501 + transforms=_landsat_satlas_transforms, + meta={ + "dataset": "Satlas", + "in_chans": 11, + "model": "swin_v2_b", + "publication": "https://arxiv.org/abs/2211.15660", + "repo": "https://github.com/allenai/satlas", + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B09", + "B10", + "B11", + ], # noqa: E501 + }, + ) + def swin_v2_b( weights: Optional[Swin_V2_B_Weights] = None, *args: Any, **kwargs: Any diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index c3c3d54f56b..19239c26bf4 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -8,7 +8,7 @@ import kornia.augmentation as K import torch from einops import rearrange -from kornia.contrib import extract_tensor_patches +from kornia.contrib import Lambda, extract_tensor_patches from kornia.geometry import crop_by_indices from kornia.geometry.boxes import Boxes from torch import Tensor @@ -25,7 +25,7 @@ class AugmentationSequential(Module): def __init__( self, - *args: Union[K.base._AugmentationBase, K.ImageSequential], + *args: Union[K.base._AugmentationBase, K.ImageSequential, Lambda], data_keys: list[str], **kwargs: Any, ) -> None: @@ -53,7 +53,7 @@ def __init__( else: keys.append(key) - self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) + self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] # noqa: E501 def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Perform augmentations and update data dict.