-
Notifications
You must be signed in to change notification settings - Fork 378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Segmentation Pretrained Weights #1046
Segmentation Pretrained Weights #1046
Conversation
This only supports ImageNet pretrained weights, which is pretty useless for us. We really want support for our weight enums like we have for regression/classification/byol. Could we manually load weights from state dict ourselves like we do with our timm-based models? |
Oh wait, I didn't read the full PR. Maybe this does cover all the features we want. You'll definitely want to update the docstring though. |
Also needs tests |
Maybe also put a note in the docs that not all our pretrained models (i.e. the ViTs) will be compatible according to the smp docs or maybe there is a way around, not sure. |
The other thing I realized is that by default smp will use torchvision for resnet backbones e.g. |
We could just only allow timm pretrained backbones here |
This actually needs #1049 to be solved otherwise this only supports resnet backbones. |
So do we want to prefix |
That is a tough question. I would be fine with only supporting timm backbones for now, which would suggest prepending |
We could prefix "tu-" if it results in a valid timm backbone to give ourselves room to grow. |
da2fb0c
to
86a0716
Compare
6324e42
to
a7695d7
Compare
a7695d7
to
fb488f4
Compare
Ugh, also the horrific model checkpoint tests |
We need this :) |
@isaaccorley does this work with our custom ResNet weights? |
I haven't tried using the weights enums but in theory it should since the encoders are just resnet models. |
It should be the same loading code and docstring description as all of the other trainers. |
@calebrob6 I just made it work with the ResNet weights from torchgeo.trainers import SemanticSegmentationTask
from torchgeo.models import ResNet50_Weights
model = SemanticSegmentationTask(
model="unet",
backbone="resnet50",
weights=ResNet50_Weights.SENTINEL2_RGB_MOCO,
in_channels=3,
num_classes=2,
loss="ce",
ignore_index=0,
learning_rate=3e-4,
learning_rate_schedule_patience=5,
freeze_backbone=False,
freeze_decoder=False,
) |
This PR addresses part of #1044 adds the ability to load pretrained weights from a backbone model e.g. ResNet into a semantic segmentation encoder. This works for the segmentation-models-pytorch Unet and DeepLabv3 implementations but not the FCN because we aren't using a backbone for that.