
Loading ViT weights into SatMAE for finetuning using SatMAE #15

Open
imantha-das opened this issue Feb 7, 2025 · 0 comments

Comments

@imantha-das

Hi,
This is not so much an issue as a question.
I want to use the MaskedAutoencoderGroupChannelViT, particularly mae_vit_base_patch16_dec512d8b, to finetune on some regional Sentinel-2A data. The issue I am having is that the weights provided by the repository, ViT-Base (200 epochs), are for the GroupChannelsVisionTransformer, i.e. the vit_base_patch16 model.

I am speculating here, but my feeling is that the SSL pretraining was conducted with the mae_vit_base_patch16 model and only the backbone weights (which I am assuming correspond to vit_base_patch16) were saved, which is common.
In usual circumstances you would only need this backbone for the downstream task.
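
A quick way to verify this would be to inspect the key prefixes in the released checkpoint. Below is a small sketch of what I mean (it assumes the checkpoint stores the state dict under the "model" key, as in my snippet further down):

import torch
from collections import Counter

checkpoint = torch.load(".../weights.ckpt", map_location="cpu")
state_dict = checkpoint["model"]

# Count parameter keys by their top-level prefix, e.g. patch_embed, blocks, norm, ...
prefixes = Counter(k.split(".")[0] for k in state_dict.keys())
print(prefixes)

# If only encoder prefixes (patch_embed, blocks, norm, ...) appear and there are no
# decoder_* keys or mask_token, the released weights are indeed backbone-only.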

But in my case I want to use the mae_vit_base_patch16_dec512d8b model for further pretraining (or finetuning) on a regional dataset. So I am wondering if there is a way to load these pretrained backbone weights into the mae_vit_base_patch16_dec512d8b model, which can then be used for further pretraining on a region-specific, smaller dataset.

I am trying to achieve something along the lines of,

import torch

# Load the SatMAE model (constructor from the SatMAE codebase)
sat_mae = mae_vit_base_patch16_dec512d8b(args)

# Get the encoder backbone
vit_backbone = sat_mae.backbone

# Load the pretrained backbone weights
checkpoint = torch.load(".../weights.ckpt", map_location="cpu")
checkpoint_model = checkpoint["model"]
vit_backbone.load_state_dict(checkpoint_model, strict=False)

# Reassign the pretrained backbone back to the SatMAE model
sat_mae.backbone = vit_backbone

# Then use sat_mae for further pretraining
Any ideas on how I can achieve this?
Thanks
