Support for ConvNeXt backbone / timm v0.5.4 #562
Comments
Hi, you could try updating timm after the smp installation and check whether it still works.
I will definitely update it later.
Hi, upgrading timm after the smp installation actually seems to work, thanks for the idea! Unfortunately, creating a UNet model with a ConvNeXt backbone gives an output that has half the width and height of the input:

```python
import segmentation_models_pytorch as smp
import timm
import torch

assert timm.__version__ == '0.5.4'

# need to use encoder_depth=4, because convnext_tiny isn't that large
model = smp.Unet(
    'tu-convnext_tiny',
    classes=11,
    activation='softmax2d',
    encoder_depth=4,
    decoder_channels=(128, 64, 32, 16),
)

dummy_input = torch.rand(1, 3, 224, 224)
output = model(dummy_input)
print(output.shape)
# gives: torch.Size([1, 11, 112, 112])
```

Do you have an idea what I'm doing wrong here?
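For reference, the halving can be derived from the encoders' cumulative per-stage strides. This is a sketch; the stride lists below are assumptions based on the standard ResNet-style layout and ConvNeXt's patchify stem, not values read from the libraries:

```python
# Cumulative downsampling per stage: a ResNet-style encoder halves resolution
# five times, while ConvNeXt's patchify stem downsamples by 4 in one step.
resnet_strides = [2, 2, 2, 2, 2]   # stem /2, then four stride-2 stages
convnext_strides = [4, 2, 2, 2]    # patchify stem /4, then three stride-2 stages

def reductions(strides):
    """Running product of strides: the reduction factor at each stage."""
    out, r = [], 1
    for s in strides:
        r *= s
        out.append(r)
    return out

print(reductions(resnet_strides))    # [2, 4, 8, 16, 32]
print(reductions(convnext_strides))  # [4, 8, 16, 32]

# With encoder_depth=4 the Unet decoder upsamples x2 four times (x16 total),
# but the deepest ConvNeXt feature is at 1/32 resolution: 224 / 32 * 16 = 112.
print(224 // 32 * 16)                # 112
```

So the decoder recovers only a factor of 16 against a total reduction of 32, which matches the 112×112 output above.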
Does it have the same behaviour with encoder_depth=5?
The problem seems to be rooted in the ConvNeXt architecture: the backbone starts with a convolution with kernel size 4 and stride 4, which is not suited for UNet, whose decoder upsamples by a factor of 2 per stage only.
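A minimal sketch of that first convolution in isolation (the layer below is a stand-in for ConvNeXt's patchify stem, not the actual timm module; the channel count 96 is what convnext_tiny uses):

```python
import torch
import torch.nn as nn

# Stand-in for ConvNeXt's patchify stem: kernel 4, stride 4.
stem = nn.Conv2d(3, 96, kernel_size=4, stride=4)

x = torch.rand(1, 3, 224, 224)
y = stem(x)
print(y.shape)  # torch.Size([1, 96, 56, 56]) — a single /4 step, no /2 stage
```

The spatial resolution drops straight from /1 to /4, so there is no /2 feature map for the first decoder skip connection.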
`Unknown model (convnext_tiny)`
It seems that one solution would be to make the strides configurable, so that the unusual stride choice in the first convolution of ConvNeXt could be handled.
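To illustrate the idea, here is a hypothetical configurable-stride stem. The function name, default channel count, and padding are made up for this sketch; neither smp nor timm exposes such an option here:

```python
import torch
import torch.nn as nn

# Hypothetical: let the caller pick the stem stride. stride=2 would produce
# the /2 first feature map that Unet's x2-per-stage decoder expects,
# instead of ConvNeXt's default /4 patchify step.
def make_patchify_stem(out_ch: int = 96, stride: int = 2) -> nn.Conv2d:
    return nn.Conv2d(3, out_ch, kernel_size=4, stride=stride, padding=1)

x = torch.rand(1, 3, 224, 224)
y2 = make_patchify_stem(stride=2)(x)
y4 = make_patchify_stem(stride=4)(x)
print(y2.shape)  # torch.Size([1, 96, 112, 112])
print(y4.shape)  # torch.Size([1, 96, 56, 56])
```

Of course, changing the stem stride would invalidate pretrained weights' spatial assumptions, so it would likely need retraining or weight adaptation.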
PyTorch Image Models supports ConvNeXt since version 0.5.4, which could be an interesting backbone for segmentation. Right now segmentation_models.pytorch still pins `timm==0.4.12`. Would it be possible to switch to the newer version?