Where is DWConv used in the LightM-UNet model? #6

Open
HashmatShadab opened this issue Mar 18, 2024 · 23 comments

@HashmatShadab (Author)

def get_dwconv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, bias: bool = False
):
    depth_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels,
                             strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True, groups=in_channels)
    point_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels,
                             strides=stride, kernel_size=1, bias=bias, conv_only=True, groups=1)
    return torch.nn.Sequential(depth_conv, point_conv)

Can you please point out where DWConv is used in constructing the model? I wasn't able to find where it is used.
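For context, get_dwconv_layer builds a depthwise-separable convolution: a grouped (per-channel) convolution followed by a 1x1 pointwise convolution. As a rough illustration of why that saves parameters, here is a minimal sketch using plain torch.nn in place of the MONAI Convolution wrapper (an approximation, not the repo's exact layer):

import torch.nn as nn

cin, cout, k = 32, 64, 3

standard = nn.Conv3d(cin, cout, kernel_size=k, bias=False)
separable = nn.Sequential(
    nn.Conv3d(cin, cin, kernel_size=k, groups=cin, bias=False),  # depthwise: one k^3 filter per channel
    nn.Conv3d(cin, cout, kernel_size=1, bias=False),             # pointwise: mixes channels
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard))   # 64 * 32 * 27 = 55296
print(count(separable))  # 32 * 27 + 64 * 32 = 2912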

@HashmatShadab (Author)

Running the model with model = LightMUNet(spatial_dims=3, init_filters=32, in_channels=1, out_channels=14, blocks_down=[1, 2, 2, 2], blocks_up=[1, 1, 1]) gives around 6 million parameters.

@eclipse0922 commented Mar 19, 2024

ResUpBlock uses get_dwconv_layer, but LightMUNet does not use ResUpBlock in the code. Currently, LightMUNet uses ResBlock.

@HashmatShadab (Author)

So the current model does not reflect the one described in the paper? Even the initial and final layers don't have DWConv. Even with these changes, the model size is 2.8 million parameters for the 3D model.

@eclipse0922
Yes, it looks like the author probably uploaded the code used for testing by mistake.

@HashmatShadab (Author)

A response from the author regarding this issue would be appreciated.

@MrBlankness (Owner)

Can you please point out where DWConv is used in constructing the model? I wasn't able to find where it is used.

Thank you very much for your attention to our work. We apologize for inadvertently uploading code from our ablation experiments in our previous submission, in which DWConv was replaced with Conv and ResUpBlock with ResBlock. We have now updated our code. We greatly appreciate the questions raised by eclipse0922 and HashmatShadab. We will continue to address any issues with LightM-UNet and strive to improve and update our work.

@HashmatShadab (Author)

So when will the correct code be uploaded?

@MrBlankness (Owner)

So when will the correct code be uploaded?

We have updated our code.

@MrBlankness (Owner)

Running the model with model = LightMUNet(spatial_dims=3, init_filters=32, in_channels=1, out_channels=14, blocks_down=[1, 2, 2, 2], blocks_up=[1, 1, 1]) gives around 6 million parameters.

Allow us to kindly remind you: the number of parameters and computations in the network varies with the network hyperparameter settings. For example, altering the model's out_channels changes the number of convolutional kernels in the final layer, and thus the parameter count.
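To make that concrete, here is a minimal sketch, assuming the final layer is a 1x1x1 Conv3d over 32 feature channels (a hypothetical stand-in, not necessarily the repo's exact head):

import torch.nn as nn

# Hypothetical segmentation head: 1x1x1 projection from 32 channels to num_classes
for num_classes in (3, 14):
    head = nn.Conv3d(32, num_classes, kernel_size=1, bias=True)
    print(num_classes, sum(p.numel() for p in head.parameters()))
# 3  -> 32*3  + 3  = 99
# 14 -> 32*14 + 14 = 462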

@HashmatShadab (Author)

model = LightMUNet(spatial_dims=3, init_filters=32, in_channels=1, out_channels=14, blocks_down=[1, 2, 2, 2], blocks_up=[1, 1, 1])

Are the above arguments correct for loading the 3D model discussed in the paper?

@MrBlankness (Owner)

model = LightMUNet(spatial_dims=3, init_filters=32, in_channels=1, out_channels=14, blocks_down=[1, 2, 2, 2], blocks_up=[1, 1, 1]) Are the above arguments correct for loading the 3D model discussed in the paper?

No, the relevant parameter settings for the LiTS dataset should be as follows:
model = LightMUNet(spatial_dims=3, init_filters=32, in_channels=1, out_channels=3, blocks_down=[1, 2, 2, 4], blocks_up=[1, 1, 1])

@HashmatShadab (Author)

So the current model does not reflect the one described in the paper? Even the initial and final layers don't have DWConv. Even with these changes, the model size is 2.8 million parameters for the 3D model.

Also, please respond to this issue as well.

@MrBlankness (Owner)

So the current model does not reflect the one described in the paper? Even the initial and final layers don't have DWConv. Even with these changes, the model size is 2.8 million parameters for the 3D model.

Also, please respond to this issue as well.

Apologies, but based solely on the information provided so far, I'm unable to determine the cause. If possible, please provide more information to help me understand your issue better.

@HashmatShadab (Author) commented Mar 29, 2024

model = LightMUNet(
    spatial_dims=3,
    init_filters=32,
    in_channels=input_channels,
    out_channels=num_classes,
    blocks_down=[1, 2, 2, 2],
    blocks_up=[1, 1, 1],
)

Using the model class from your updated code, I am getting 2.9M parameters.

@MrBlankness (Owner)

How many input_channels and num_classes have you set?

@HashmatShadab (Author)

Input channels: 1
Output channels: 14

@MrBlankness (Owner)

Input channels: 1 Output channels: 14

We're sorry, but we are unable to reproduce your issue. 🤦

@HashmatShadab (Author)

Can you please provide a short script for this then? How many total parameters are you getting?

@MrBlankness (Owner) commented Mar 29, 2024

import torch
from thop import profile
# LightMUNet is assumed to be imported from this repository
model = LightMUNet(
    spatial_dims = 3,
    init_filters = 32,
    in_channels=1,
    out_channels=14,
    blocks_down=[1, 2, 2, 2],
    blocks_up=[1, 1, 1],
).cuda()

data = torch.rand(1, 1, 256, 256, 256).cuda()

_, params = profile(model, inputs=(data, ))
print(params/1e6)

I get a total parameter count of 0.4667M.

@HashmatShadab (Author)

import torch
from thop import profile
# LightMUNet is assumed to be imported from this repository

model = LightMUNet(
    spatial_dims = 3,
    init_filters = 32,
    in_channels=1,
    out_channels=14,
    blocks_down=[1, 2, 2, 2],
    blocks_up=[1, 1, 1],
).cuda()

data = torch.rand(1, 1, 256, 256, 256).cuda()

# Count trainable parameters directly from the module tree
model_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Model Parameters = {model_total_params:,}\n")

# Count parameters via thop's forward-hook profiler
_, params = profile(model, inputs=(data, ))
print(params)

Total Model Parameters = 2,997,821

[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool3d'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
WARNING:root:mode trilinear is not implemented yet, take it a zero op
WARNING:root:mode trilinear is not implemented yet, take it a zero op
WARNING:root:mode trilinear is not implemented yet, take it a zero op
466729.0

@MrBlankness (Owner)

Perhaps the two methods of parameter counting work on different principles? I'm not sure of the reason, though.
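One plausible explanation, stated here as an assumption about thop rather than a verified diagnosis: thop counts parameters through forward hooks that it registers only for module types it recognizes (the classes listed in the INFO log above), so parameters held inside unrecognized custom modules, such as Mamba blocks, can be skipped, while sum(p.numel() for p in model.parameters()) walks every registered tensor. A minimal sketch of the effect, with CustomBlock as a hypothetical stand-in:

import torch
import torch.nn as nn
from thop import profile

class CustomBlock(nn.Module):
    # Hypothetical stand-in for a module type thop has no handler for (e.g. a Mamba block)
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1000))  # raw parameter, not wrapped in nn.Conv/nn.Linear
    def forward(self, x):
        return x

model = nn.Sequential(nn.Conv1d(1, 1, 3), CustomBlock())
_, params = profile(model, inputs=(torch.rand(1, 1, 16),))
print(params)                                      # only what thop recognizes (the Conv1d)
print(sum(p.numel() for p in model.parameters()))  # everything: 3 + 1 + 1000 = 1004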

@HashmatShadab (Author)

I haven't used the thop package before, so I am also a bit confused. Using from torchinfo import summary to calculate parameters also gives 2.9M parameters.
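For reference, that torchinfo cross-check might look like this (assuming torchinfo is installed; the input size mirrors the profiling script above):

from torchinfo import summary

# torchinfo's "Total params" walks model.parameters(), so it should agree with
# the manual sum (~2.99M) rather than with thop's hook-based count.
summary(model, input_size=(1, 1, 256, 256, 256), depth=3)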

@MrBlankness (Owner)

Anyway, thank you very much for your feedback, and we will follow up on this issue.
