Skip to content

Commit

Permalink
update data transform for loading resnet50 sentinel 1 weights
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyi111 committed Aug 19, 2024
1 parent 8da11b3 commit 8d0b3c1
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions torchgeo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
data_keys=None,
)

_mean_s1 = torch.tensor([-12.59, -20.26])
_std_s1 = torch.tensor([5.26, 5.91])
_zhu_xlab_transforms_s1 = K.AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=_mean_s1, std=_std_s1),
data_keys=None,
)

# Normalization only available for RGB dataset, defined here:
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
_min = torch.tensor([3, 2, 0])
Expand Down Expand Up @@ -393,7 +402,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]

SENTINEL1_ALL_DECUR = Weights(
url='https://huggingface.co/torchgeo/decur/resolve/9328eeb90c686a88b30f8526ed757b4bc0f12027/rn50_ssl4eo-s12_sar_decur_ep100-f0e69ba2.pth',
transforms=_zhu_xlab_transforms,
transforms=_zhu_xlab_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
Expand All @@ -406,7 +415,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]

SENTINEL1_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth',
transforms=_zhu_xlab_transforms,
transforms=_zhu_xlab_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
Expand Down

0 comments on commit 8d0b3c1

Please sign in to comment.