diff --git a/.gitignore b/.gitignore index b74875a74d..c6aea2258b 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,7 @@ tests/testing_data/CT_2D_head_moving.mha # profiling results *.prof runs + +*.gz + +*.pth diff --git a/monai/networks/blocks/denseblock.py b/monai/networks/blocks/denseblock.py index ecccab9d5a..8c67584f5f 100644 --- a/monai/networks/blocks/denseblock.py +++ b/monai/networks/blocks/denseblock.py @@ -11,7 +11,7 @@ from __future__ import annotations -from collections.abc import Sequence +from typing import Sequence import torch import torch.nn as nn diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 95ddad7842..a0c8628172 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -17,6 +17,7 @@ from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator +from .daf3d import DAF3D from .densenet import ( DenseNet, Densenet, @@ -51,6 +52,7 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter +from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet from .resnet import ( diff --git a/monai/networks/nets/daf3d.py b/monai/networks/nets/daf3d.py new file mode 100644 index 0000000000..5a83cdc600 --- /dev/null +++ b/monai/networks/nets/daf3d.py @@ -0,0 +1,574 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from collections import OrderedDict +from collections.abc import Callable, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from monai.networks.blocks import ADN +from monai.networks.blocks.aspp import SimpleASPP +from monai.networks.blocks.backbone_fpn_utils import BackboneWithFPN +from monai.networks.blocks.convolutions import Convolution +from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork +from monai.networks.layers.factories import Conv, Norm +from monai.networks.nets.resnet import ResNet, ResNetBottleneck + +__all__ = [ + "AttentionModule", + "Daf3dASPP", + "Daf3dResNetBottleneck", + "Daf3dResNetDilatedBottleneck", + "Daf3dResNet", + "Daf3dBackbone", + "Daf3dFPN", + "Daf3dBackboneWithFPN", + "DAF3D", +] + + +class AttentionModule(nn.Module): + """ + Attention Module as described in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' + . Returns refined single layer feature (SLF) and attentive map + + Args: + spatial_dims: dimension of inputs. + in_channels: number of input channels (channels of slf and mlf). + out_channels: number of output channels (channels of attentive map and refined slf). + norm: normalization type. + act: activation type. + """ + + def __init__( + self, + spatial_dims, + in_channels, + out_channels, + norm=("group", {"num_groups": 32, "num_channels": 64}), + act="PRELU", + ): + super().__init__() + + self.attentive_map = nn.Sequential( + Convolution(spatial_dims, in_channels, out_channels, kernel_size=1, norm=norm, act=act), + Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act), + Convolution( + spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, adn_ordering="A", act="SIGMOID" + ), + ) + self.refine = nn.Sequential( + Convolution(spatial_dims, in_channels, out_channels, kernel_size=1, norm=norm, act=act), + Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act), + Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act), + ) + + def forward(self, slf, mlf): + att = self.attentive_map(torch.cat((slf, mlf), 1)) + out = self.refine(torch.cat((slf, att * mlf), 1)) + return (out, att) + + +class Daf3dASPP(SimpleASPP): + """ + Atrous Spatial Pyramid Pooling module as used in 'Deep Attentive Features for Prostate Segmentation in + 3D Transrectal Ultrasound' . Core functionality as in SimpleASPP, but after each + layerwise convolution a group normalization is added. Further weight initialization for convolutions is provided in + _init_weight(). Additional possibility to specify the number of final output channels. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + in_channels: number of input channels. + conv_out_channels: number of output channels of each atrous conv. + out_channels: number of output channels of final convolution. + If None, uses len(kernel_sizes) * conv_out_channels + kernel_sizes: a sequence of four convolutional kernel sizes. + Defaults to (1, 3, 3, 3) for four (dilated) convolutions. + dilations: a sequence of four convolutional dilation parameters. + Defaults to (1, 2, 4, 6) for four (dilated) convolutions. + norm_type: final kernel-size-one convolution normalization type. + Defaults to batch norm. + acti_type: final kernel-size-one convolution activation type. + Defaults to leaky ReLU. + bias: whether to have a bias term in convolution blocks. Defaults to False. + According to `Performance Tuning Guide `_, + if a conv layer is directly followed by a batch norm layer, bias should be False. + + Raises: + ValueError: When ``kernel_sizes`` length differs from ``dilations``. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + conv_out_channels: int, + out_channels: int | None = None, + kernel_sizes: Sequence[int] = (1, 3, 3, 3), + dilations: Sequence[int] = (1, 2, 4, 6), + norm_type: tuple | str | None = "BATCH", + acti_type: tuple | str | None = "LEAKYRELU", + bias: bool = False, + ) -> None: + super().__init__( + spatial_dims, in_channels, conv_out_channels, kernel_sizes, dilations, norm_type, acti_type, bias + ) + + # add normalization after each atrous convolution, initializes weights + new_convs = nn.ModuleList() + for _conv in self.convs: + tmp_conv = Convolution(1, 1, 1) + tmp_conv.conv = _conv + tmp_conv.adn = ADN(ordering="N", norm=norm_type, norm_dim=1) + tmp_conv = self._init_weight(tmp_conv) + new_convs.append(tmp_conv) + self.convs = new_convs + + # change final convolution to different out_channels + if out_channels is None: + out_channels = len(kernel_sizes) * conv_out_channels + + self.conv_k1 = Convolution( + spatial_dims=3, + in_channels=len(kernel_sizes) * conv_out_channels, + out_channels=out_channels, + kernel_size=1, + norm=norm_type, + act=acti_type, + ) + + def _init_weight(self, conv): + for m in conv.modules(): + if isinstance(m, nn.Conv3d): # true for conv.conv + torch.nn.init.kaiming_normal_(m.weight) + return conv + + +class Daf3dResNetBottleneck(ResNetBottleneck): + """ + ResNetBottleneck block as used in 'Deep Attentive Features for Prostate Segmentation in 3D + Transrectal Ultrasound' . + Instead of Batch Norm Group Norm is used, instead of ReLU PReLU activation is used. + Initial expansion is 2 instead of 4 and second convolution uses groups. + + Args: + in_planes: number of input channels. + planes: number of output channels (taking expansion into account). + spatial_dims: number of spatial dimensions of the input image. + stride: stride to use for second conv layer. + downsample: which downsample layer to use. + """ + + expansion = 2 + + def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None): + norm_type: Callable = Norm[Norm.GROUP, spatial_dims] + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + + # in case downsample uses batch norm, change to group norm + if isinstance(downsample, nn.Sequential): + downsample = nn.Sequential( + conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False), + norm_type(num_groups=32, num_channels=planes * self.expansion), + ) + + super().__init__(in_planes, planes, spatial_dims, stride, downsample) + + # change norm from batch to group norm + self.bn1 = norm_type(num_groups=32, num_channels=planes) + self.bn2 = norm_type(num_groups=32, num_channels=planes) + self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion) + + # adapt second convolution to work with groups + self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False) + + # adapt activation function + self.relu = nn.PReLU() # type: ignore + + +class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck): + """ + ResNetDilatedBottleneck as used in 'Deep Attentive Features for Prostate Segmentation in 3D + Transrectal Ultrasound' . + Same as Daf3dResNetBottleneck but dilation of 2 is used in second convolution. + Args: + in_planes: number of input channels. + planes: number of output channels (taking expansion into account). + spatial_dims: number of spatial dimensions of the input image. + stride: stride to use for second conv layer. + downsample: which downsample layer to use. + """ + + def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None): + super().__init__(in_planes, planes, spatial_dims, stride, downsample) + + # add dilation in second convolution + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + self.conv2 = conv_type( + planes, planes, kernel_size=3, stride=stride, padding=2, dilation=2, groups=32, bias=False + ) + + +class Daf3dResNet(ResNet): + """ + ResNet as used in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' + . + Uses two Daf3dResNetBottleneck blocks followed by two Daf3dResNetDilatedBottleneck blocks. + + Args: + layers: how many layers to use. + block_inplanes: determine the size of planes at each step. Also tunable with widen_factor. + spatial_dims: number of spatial dimensions of the input image. + n_input_channels: number of input channels for first convolutional layer. + conv1_t_size: size of first convolution layer, determines kernel and padding. + conv1_t_stride: stride of first convolution layer. + no_max_pool: bool argument to determine if to use maxpool layer. + shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'. + - 'A': using `self._downsample_basic_block`. + - 'B': kernel_size 1 conv + norm. + widen_factor: widen output for each layer. + num_classes: number of output (classifications). + feed_forward: whether to add the FC layer for the output, default to `True`. + bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`. + + """ + + def __init__( + self, + layers: list[int], + block_inplanes: list[int], + spatial_dims: int = 3, + n_input_channels: int = 3, + conv1_t_size: tuple[int] | int = 7, + conv1_t_stride: tuple[int] | int = 1, + no_max_pool: bool = False, + shortcut_type: str = "B", + widen_factor: float = 1.0, + num_classes: int = 400, + feed_forward: bool = True, + bias_downsample: bool = True, # for backwards compatibility (also see PR #5477) + ): + super().__init__( + ResNetBottleneck, + layers, + block_inplanes, + spatial_dims, + n_input_channels, + conv1_t_size, + conv1_t_stride, + no_max_pool, + shortcut_type, + widen_factor, + num_classes, + feed_forward, + bias_downsample, + ) + + self.in_planes = 64 + + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + norm_type: Callable = Norm[Norm.GROUP, spatial_dims] + + # adapt first convolution to work with new in_planes + self.conv1 = conv_type( + n_input_channels, self.in_planes, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False + ) + self.bn1 = norm_type(32, 64) + self.relu = nn.PReLU() # type: ignore + + # adapt layers to our needs + self.layer1 = self._make_layer(Daf3dResNetBottleneck, block_inplanes[0], layers[0], spatial_dims, shortcut_type) + self.layer2 = self._make_layer( + Daf3dResNetBottleneck, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=(1, 2, 2) # type: ignore + ) + self.layer3 = self._make_layer( + Daf3dResNetDilatedBottleneck, block_inplanes[2], layers[2], spatial_dims, shortcut_type, stride=1 + ) + self.layer4 = self._make_layer( + Daf3dResNetDilatedBottleneck, block_inplanes[3], layers[3], spatial_dims, shortcut_type, stride=1 + ) + + +class Daf3dBackbone(nn.Module): + """ + Backbone for 3D Feature Pyramid Network in DAF3D module based on 'Deep Attentive Features for Prostate Segmentation in + 3D Transrectal Ultrasound' . + + Args: + n_input_channels: number of input channels for the first convolution. + """ + + def __init__(self, n_input_channels): + super().__init__() + net = Daf3dResNet( + layers=[3, 4, 6, 3], + block_inplanes=[128, 256, 512, 1024], + n_input_channels=n_input_channels, + num_classes=2, + bias_downsample=False, + ) + net_modules = list(net.children()) + self.layer0 = nn.Sequential(*net_modules[:3]) + self.layer1 = nn.Sequential(*net_modules[3:5]) + self.layer2 = net_modules[5] + self.layer3 = net_modules[6] + self.layer4 = net_modules[7] + + def forward(self, x): + layer0 = self.layer0(x) + layer1 = self.layer1(layer0) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + return layer4 + + +class Daf3dFPN(FeaturePyramidNetwork): + """ + Feature Pyramid Network as used in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' + . + Omits 3x3x3 convolution of layer_blocks and interpolates resulting feature maps to be the same size as + feature map with highest resolution. + + Args: + spatial_dims: 2D or 3D images + in_channels_list: number of channels for each feature map that is passed to the module + out_channels: number of channels of the FPN representation + extra_blocks: if provided, extra operations will be performed. + It is expected to take the fpn features, the original + features and the names of the original features as input, and returns + a new list of feature maps and their corresponding names + """ + + def __init__( + self, + spatial_dims: int, + in_channels_list: list[int], + out_channels: int, + extra_blocks: ExtraFPNBlock | None = None, + ): + super().__init__(spatial_dims, in_channels_list, out_channels, extra_blocks) + + self.inner_blocks = nn.ModuleList() + for in_channels in in_channels_list: + if in_channels == 0: + raise ValueError("in_channels=0 is currently not supported") + inner_block_module = Convolution( + spatial_dims, + in_channels, + out_channels, + kernel_size=1, + adn_ordering="NA", + act="PRELU", + norm=("group", {"num_groups": 32, "num_channels": 128}), + ) + self.inner_blocks.append(inner_block_module) + + def forward(self, x: dict[str, Tensor]) -> dict[str, Tensor]: + # unpack OrderedDict into two lists for easier handling + names = list(x.keys()) + x_values: list[Tensor] = list(x.values()) + + last_inner = self.get_result_from_inner_blocks(x_values[-1], -1) + results = [] + results.append(last_inner) + + for idx in range(len(x_values) - 2, -1, -1): + inner_lateral = self.get_result_from_inner_blocks(x_values[idx], idx) + feat_shape = inner_lateral.shape[2:] + inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="trilinear") + last_inner = inner_lateral + inner_top_down + results.insert(0, last_inner) + + if self.extra_blocks is not None: + results, names = self.extra_blocks(results, x_values, names) + + # bring all layers to same size + results = [results[0]] + [F.interpolate(l, size=x["feat1"].size()[2:], mode="trilinear") for l in results[1:]] + # make it back an OrderedDict + out = OrderedDict(list(zip(names, results))) + + return out + + +class Daf3dBackboneWithFPN(BackboneWithFPN): + """ + Same as BackboneWithFPN but uses custom Daf3DFPN as feature pyramid network + + Args: + backbone: backbone network + return_layers: a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). + in_channels_list: number of channels for each feature map + that is returned, in the order they are present in the OrderedDict + out_channels: number of channels in the FPN. + spatial_dims: 2D or 3D images + extra_blocks: if provided, extra operations will + be performed. It is expected to take the fpn features, the original + features and the names of the original features as input, and returns + a new list of feature maps and their corresponding names + """ + + def __init__( + self, + backbone: nn.Module, + return_layers: dict[str, str], + in_channels_list: list[int], + out_channels: int, + spatial_dims: int | None = None, + extra_blocks: ExtraFPNBlock | None = None, + ) -> None: + super().__init__(backbone, return_layers, in_channels_list, out_channels, spatial_dims, extra_blocks) + + if spatial_dims is None: + if hasattr(backbone, "spatial_dims") and isinstance(backbone.spatial_dims, int): + spatial_dims = backbone.spatial_dims + elif isinstance(backbone.conv1, nn.Conv2d): + spatial_dims = 2 + elif isinstance(backbone.conv1, nn.Conv3d): + spatial_dims = 3 + else: + raise ValueError( + "Could not determine value of `spatial_dims` from backbone, please provide explicit value." + ) + + self.fpn = Daf3dFPN(spatial_dims, in_channels_list, out_channels, extra_blocks) + + +class DAF3D(nn.Module): + """ + DAF3D network based on 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' + . + The network consists of a 3D Feature Pyramid Network which is applied on the feature maps of a 3D ResNet, + followed by a custom Attention Module and an ASPP module. + During training the supervised signal consists of the outputs of the FPN (four Single Layer Features, SLFs), + the outputs of the attention module (four Attentive Features) and the final prediction. + They are individually compared to the ground truth, the final loss consists of a weighted sum of all + individual losses (see DAF3D tutorial for details). + There is an additional possiblity to return all supervised signals as well as the Attentive Maps in validation + mode to visualize inner functionality of the network. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + visual_output: whether to return all SLFs, Attentive Maps, Refined SLFs in validation mode + can be used to visualize inner functionality of the network + """ + + def __init__(self, in_channels, out_channels, visual_output=False): + super().__init__() + self.visual_output = visual_output + self.backbone_with_fpn = Daf3dBackboneWithFPN( + backbone=Daf3dBackbone(in_channels), + return_layers={"layer1": "feat1", "layer2": "feat2", "layer3": "feat3", "layer4": "feat4"}, + in_channels_list=[256, 512, 1024, 2048], + out_channels=128, + spatial_dims=3, + ) + self.predict1 = nn.Conv3d(128, out_channels, kernel_size=1) + + group_norm = ("group", {"num_groups": 32, "num_channels": 64}) + act_prelu = ("prelu", {"num_parameters": 1, "init": 0.25}) + self.fuse = nn.Sequential( + Convolution( + spatial_dims=3, + in_channels=512, + out_channels=64, + kernel_size=1, + adn_ordering="NA", + norm=group_norm, + act=act_prelu, + ), + Convolution( + spatial_dims=3, + in_channels=64, + out_channels=64, + kernel_size=3, + adn_ordering="NA", + padding=1, + norm=group_norm, + act=act_prelu, + ), + Convolution( + spatial_dims=3, + in_channels=64, + out_channels=64, + kernel_size=3, + adn_ordering="NA", + padding=1, + norm=group_norm, + act=act_prelu, + ), + ) + self.attention = AttentionModule( + spatial_dims=3, in_channels=192, out_channels=64, norm=group_norm, act=act_prelu + ) + + self.refine = Convolution(3, 256, 64, kernel_size=1, adn_ordering="NA", norm=group_norm, act=act_prelu) + self.predict2 = nn.Conv3d(64, out_channels, kernel_size=1) + self.aspp = Daf3dASPP( + spatial_dims=3, + in_channels=64, + conv_out_channels=64, + out_channels=64, + kernel_sizes=(3, 3, 3, 3), + dilations=((1, 1, 1), (1, 6, 6), (1, 12, 12), (1, 18, 18)), # type: ignore + norm_type=group_norm, + acti_type=None, + bias=True, + ) + + def forward(self, x): + # layers from 1 - 4 + single_layer_features = list(self.backbone_with_fpn(x).values()) + + # first 4 supervised signals (SLFs 1 - 4) + supervised1 = [self.predict1(slf) for slf in single_layer_features] + + mlf = self.fuse(torch.cat(single_layer_features, 1)) + + attentive_features_maps = [self.attention(slf, mlf) for slf in single_layer_features] + att_features, att_maps = tuple(zip(*attentive_features_maps)) + + # second 4 supervised signals (af 1 - 4) + supervised2 = [self.predict2(af) for af in att_features] + + # attentive maps as optional additional output + supervised3 = [self.predict2(am) for am in att_maps] + + attentive_mlf = self.refine(torch.cat(att_features, 1)) + + aspp = self.aspp(attentive_mlf) + + supervised_final = self.predict2(aspp) + + if self.training: + output = supervised1 + supervised2 + [supervised_final] + output = [F.interpolate(o, size=x.size()[2:], mode="trilinear") for o in output] + else: + if self.visual_output: + supervised_final = F.interpolate(supervised_final, size=x.size()[2:], mode="trilinear") + supervised_inner = [ + F.interpolate(o, size=x.size()[2:], mode="trilinear") + for o in supervised1 + supervised2 + supervised3 + ] + output = [supervised_final] + supervised_inner + else: + output = F.interpolate(supervised_final, size=x.size()[2:], mode="trilinear") + return output diff --git a/monai/networks/nets/quicknat.py b/monai/networks/nets/quicknat.py new file mode 100644 index 0000000000..cbcccf24d7 --- /dev/null +++ b/monai/networks/nets/quicknat.py @@ -0,0 +1,439 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import ConvDenseBlock, Convolution +from monai.networks.blocks import squeeze_and_excitation as se +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.simplelayers import SkipConnection +from monai.networks.layers.utils import get_dropout_layer, get_pool_layer +from monai.utils import optional_import + +# Lazy import to avoid dependency +se1, flag = optional_import("squeeze_and_excitation") + +__all__ = ["Quicknat"] + +# QuickNAT specific Blocks + + +class SkipConnectionWithIdx(SkipConnection): + """ + Combine the forward pass input with the result from the given submodule:: + --+--submodule--o-- + |_____________| + The available modes are ``"cat"``, ``"add"``, ``"mul"``. + Defaults to "cat" and dimension 1. + Inherits from SkipConnection but provides the indizes with each forward pass. + """ + + def forward(self, input, indices): + return super().forward(input), indices + + +class SequentialWithIdx(nn.Sequential): + """ + A sequential container. + Modules will be added to it in the order they are passed in the + constructor. + Own implementation to work with the new indices in the forward pass. + """ + + def __init__(self, *args): + super().__init__(*args) + + def forward(self, input, indices): + for module in self: + input, indices = module(input, indices) + return input, indices + + +class ClassifierBlock(Convolution): + """ + Returns a classifier block without an activation function at the top. + It consists of a 1 * 1 convolutional layer which maps the input to a num_class channel feature map. + The output is a probability map for each of the classes. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of classes to map to. + strides: convolution stride. Defaults to 1. + kernel_size: convolution kernel size. Defaults to 3. + adn_ordering: a string representing the ordering of activation, normalization, and dropout. + Defaults to "NDA". + act: activation type and arguments. Defaults to PReLU. + + """ + + def __init__(self, spatial_dims, in_channels, out_channels, strides, kernel_size, act=None, adn_ordering="A"): + super().__init__(spatial_dims, in_channels, out_channels, strides, kernel_size, adn_ordering, act) + + def forward(self, input: torch.Tensor, weights=None, indices=None): + _, channel, *dims = input.size() + if weights is not None: + weights, _ = torch.max(weights, dim=0) + weights = weights.view(1, channel, 1, 1) + # use weights to adapt how the classes are weighted. + if len(dims) == 2: + out_conv = F.conv2d(input, weights) + else: + raise ValueError("Quicknat is a 2D architecture, please check your dimension.") + else: + out_conv = super().forward(input) + # no indices to return + return out_conv, None + + +# Quicknat specific blocks. All blocks inherit from MONAI blocks but have adaptions to their structure +class ConvConcatDenseBlock(ConvDenseBlock): + """ + This dense block is defined as a sequence of 'Convolution' blocks. It overwrite the '_get_layer' methodto change the ordering of + Every convolutional layer is preceded by a batch-normalization layer and a Rectifier Linear Unit (ReLU) layer. + The first two convolutional layers are followed by a concatenation layer that concatenates + the input feature map with outputs of the current and previous convolutional blocks. + Kernel size of two convolutional layers kept small to limit number of paramters. + Appropriate padding is provided so that the size of feature maps before and after convolution remains constant. + The output channels for each convolution layer is set to 64, which acts as a bottle- neck for feature map selectivity. + The input channel size is variable, depending on the number of dense connections. + The third convolutional layer is also preceded by a batch normalization and ReLU, + but has a 1 * 1 kernel size to compress the feature map size to 64. + Args: + in_channles: variable depending on depth of the network + seLayer: Squeeze and Excite block to be included, defaults to None, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'}, + dropout_layer: Dropout block to be included, defaults to None. + :return: forward passed tensor + """ + + def __init__( + self, + in_channels: int, + se_layer: Optional[nn.Module] = None, + dropout_layer: Optional[nn.Dropout2d] = None, + kernel_size: Sequence[int] | int = 5, + num_filters: int = 64, + ): + self.count = 0 + super().__init__( + in_channels=in_channels, + spatial_dims=2, + # number of channels stay constant throughout the convolution layers + channels=[num_filters, num_filters, num_filters], + norm=("instance", {"num_features": in_channels}), + kernel_size=kernel_size, + ) + self.se_layer = se_layer if se_layer is not None else nn.Identity() + self.dropout_layer = dropout_layer if dropout_layer is not None else nn.Identity() + + def _get_layer(self, in_channels, out_channels, dilation): + """ + After ever convolutional layer the output is concatenated with the input and the layer before. + The concatenated output is used as input to the next convolutional layer. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + strides: convolution stride. + is_top: True if this is the top block. + """ + kernelsize = self.kernel_size if self.count < 2 else (1, 1) + # padding = None if self.count < 2 else (0, 0) + self.count += 1 + conv = Convolution( + spatial_dims=self.spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + strides=1, + kernel_size=kernelsize, + act=self.act, + norm=("instance", {"num_features": in_channels}), + ) + return nn.Sequential(conv.get_submodule("adn"), conv.get_submodule("conv")) + + def forward(self, input, _): + i = 0 + result = input + for l in self.children(): + # ignoring the max (un-)pool and droupout already added in the initial initialization step + if isinstance(l, (nn.MaxPool2d, nn.MaxUnpool2d, nn.Dropout2d)): + continue + # first convolutional forward + result = l(result) + if i == 0: + result1 = result + # concatenation with the input feature map + result = torch.cat((input, result), dim=1) + + if i == 1: + # concatenation with input feature map and feature map from first convolution + result = torch.cat((result1, result, input), dim=1) + i = i + 1 + + # if SELayer or Dropout layer defined put output through layer before returning, + # else it just goes through nn.Identity and the output does not change + result = self.se_layer(result) + result = self.dropout_layer(result) + + return result, None + + +class Encoder(ConvConcatDenseBlock): + """ + Returns a convolution dense block for the encoding (down) part of a layer of the network. + This Encoder block downpools the data with max_pool. + Its output is used as input to the next layer down. + New feature: it returns the indices of the max_pool to the decoder (up) path + at the same layer to upsample the input. + + Args: + in_channels: number of input channels. + max_pool: predefined max_pool layer to downsample the data. + se_layer: Squeeze and Excite block to be included, defaults to None. + dropout: Dropout block to be included, defaults to None. + kernel_size : kernel size of the convolutional layers. Defaults to 5*5 + num_filters : number of input channels to each convolution block. Defaults to 64 + """ + + def __init__(self, in_channels: int, max_pool, se_layer, dropout, kernel_size, num_filters): + super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters) + self.max_pool = max_pool + + def forward(self, input, indices=None): + input, indices = self.max_pool(input) + + out_block, _ = super().forward(input, None) + # safe the indices for unpool on decoder side + return out_block, indices + + +class Decoder(ConvConcatDenseBlock): + """ + Returns a convolution dense block for the decoding (up) part of a layer of the network. + This will upsample data with an unpool block before the forward. + It uses the indices from corresponding encoder on it's level. + Its output is used as input to the next layer up. + + Args: + in_channels: number of input channels. + un_pool: predefined unpool block. + se_layer: predefined SELayer. Defaults to None. + dropout: predefined dropout block. Defaults to None. + kernel_size: Kernel size of convolution layers. Defaults to 5*5. + num_filters: number of input channels to each convolution layer. Defaults to 64. + """ + + def __init__(self, in_channels: int, un_pool, se_layer, dropout, kernel_size, num_filters): + super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters) + self.un_pool = un_pool + + def forward(self, input, indices): + out_block, _ = super().forward(input, None) + out_block = self.un_pool(out_block, indices) + return out_block, None + + +class Bottleneck(ConvConcatDenseBlock): + """ + Returns the bottom or bottleneck layer at the bottom of a network linking encoder to decoder halves. + It consists of a 5 * 5 convolutional layer and a batch normalization layer to separate + the encoder and decoder part of the network, restricting information flow between the encoder and decoder. + + Args: + in_channels: number of input channels. + se_layer: predefined SELayer. Defaults to None. + dropout: predefined dropout block. Defaults to None. + un_pool: predefined unpool block. + max_pool: predefined maxpool block. + kernel_size: Kernel size of convolution layers. Defaults to 5*5. + num_filters: number of input channels to each convolution layer. Defaults to 64. + """ + + def __init__(self, in_channels: int, se_layer, dropout, max_pool, un_pool, kernel_size, num_filters): + super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters) + self.max_pool = max_pool + self.un_pool = un_pool + + def forward(self, input, indices): + out_block, indices = self.max_pool(input) + out_block, _ = super().forward(out_block, None) + out_block = self.un_pool(out_block, indices) + return out_block, None + + +class Quicknat(nn.Module): + """ + Model for "Quick segmentation of NeuroAnaTomy (QuickNAT) based on a deep fully convolutional neural network. + Refer to: "QuickNAT: A Fully Convolutional Network for Quick and Accurate Segmentation of Neuroanatomy by + Abhijit Guha Roya, Sailesh Conjetib, Nassir Navabb, Christian Wachingera" + + QuickNAT has an encoder/decoder like 2D F-CNN architecture with 4 encoders and 4 decoders separated by a bottleneck layer. + The final layer is a classifier block with softmax. + The architecture includes skip connections between all encoder and decoder blocks of the same spatial resolution, + similar to the U-Net architecture. + All Encoder and Decoder consist of three convolutional layers all with a Batch Normalization and ReLU. + The first two convolutional layers are followed by a concatenation layer that concatenates + the input feature map with outputs of the current and previous convolutional blocks. + The kernel size of the first two convolutional layers is 5*5, the third convolutional layer has a kernel size of 1*1. + + Data in the encode path is downsampled using max pooling layers instead of upsamling like UNet and in the decode path + upsampled using max un-pooling layers instead of transpose convolutions. + The pooling is done at the beginning of the block and the unpool afterwards. + The indices of the max pooling in the Encoder are forwarded through the layer to be available to the corresponding Decoder. + + The bottleneck block consists of a 5 * 5 convolutional layer and a batch normalization layer + to separate the encoder and decoder part of the network, + restricting information flow between the encoder and decoder. + + The output feature map from the last decoder block is passed to the classifier block, + which is a convolutional layer with 1 * 1 kernel size that maps the input to an N channel feature map, + where N is the number of segmentation classes. + + To further explain this consider the first example network given below. This network has 3 layers with strides + of 2 for each of the middle layers (the last layer is the bottom connection which does not down/up sample). Input + data to this network is immediately reduced in the spatial dimensions by a factor of 2 by the first convolution of + the residual unit defining the first layer of the encode part. The last layer of the decode part will upsample its + input (data from the previous layer concatenated with data from the skip connection) in the first convolution. this + ensures the final output of the network has the same shape as the input. + + The original QuickNAT implementation included a `enable_test_dropout()` mechanism for uncertainty estimation during + testing. As the dropout layers are the only stochastic components of this network calling the train() method instead + of eval() in testing or inference has the same effect. + + Args: + num_classes: number of classes to segmentate (output channels). + num_channels: number of input channels. + num_filters: number of output channels for each convolutional layer in a Dense Block. + kernel_size: size of the kernel of each convolutional layer in a Dense Block. + kernel_c: convolution kernel size of classifier block kernel. + stride_convolution: convolution stride. Defaults to 1. + pool: kernel size of the pooling layer, + stride_pool: stride for the pooling layer. + se_block: Squeeze and Excite block type to be included, defaults to None. Valid options : NONE, CSE, SSE, CSSE, + droup_out: dropout ratio. Defaults to no dropout. + act: activation type and arguments. Defaults to PReLU. + norm: feature normalization type and arguments. Defaults to instance norm. + adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D). + Defaults to "NA". See also: :py:class:`monai.networks.blocks.ADN`. + + Examples:: + + from monai.networks.nets import QuickNAT + + # network with max pooling by a factor of 2 at each layer with no se_block. + net = QuickNAT( + num_classes=3, + num_channels=1, + num_filters=64, + pool = 2, + se_block = "None" + ) + + """ + + def __init__( + self, + num_classes: int = 33, + num_channels: int = 1, + num_filters: int = 64, + kernel_size: Sequence[int] | int = 5, + kernel_c: int = 1, + stride_conv: int = 1, + pool: int = 2, + stride_pool: int = 2, + # Valid options : NONE, CSE, SSE, CSSE + se_block: str = "None", + drop_out: float = 0, + act: Union[Tuple, str] = Act.PRELU, + norm: Union[Tuple, str] = Norm.INSTANCE, + adn_ordering: str = "NA", + ) -> None: + self.act = act + self.norm = norm + self.adn_ordering = adn_ordering + super().__init__() + se_layer = self.get_selayer(num_filters, se_block) + dropout_layer = get_dropout_layer(name=("dropout", {"p": drop_out}), dropout_dim=2) + max_pool = get_pool_layer( + name=("max", {"kernel_size": pool, "stride": stride_pool, "return_indices": True, "ceil_mode": True}), + spatial_dims=2, + ) + # for the unpooling layer there is currently no Monai implementation available, return to torch implementation + un_pool = nn.MaxUnpool2d(kernel_size=pool, stride=stride_pool) + + # sequence of convolutional strides (like in UNet) not needed as they are always stride_conv. This defaults to 1. + def _create_model(layer: int) -> nn.Module: + """ + Builds the QuickNAT structure from the bottom up by recursing down to the bottelneck layer, then creating sequential + blocks containing the decoder, a skip connection around the previous block, and the encoder. + At the last layer a classifier block is added to the Sequential. + + Args: + layer = inversproportional to the layers left to create + """ + subblock: nn.Module + if layer < 4: + subblock = _create_model(layer + 1) + + else: + subblock = Bottleneck(num_filters, se_layer, dropout_layer, max_pool, un_pool, kernel_size, num_filters) + + if layer == 1: + down = ConvConcatDenseBlock(num_channels, se_layer, dropout_layer, kernel_size, num_filters) + up = ConvConcatDenseBlock(num_filters * 2, se_layer, dropout_layer, kernel_size, num_filters) + classifier = ClassifierBlock(2, num_filters, num_classes, stride_conv, kernel_c) + return SequentialWithIdx(down, SkipConnectionWithIdx(subblock), up, classifier) + else: + up = Decoder(num_filters * 2, un_pool, se_layer, dropout_layer, kernel_size, num_filters) + down = Encoder(num_filters, max_pool, se_layer, dropout_layer, kernel_size, num_filters) + return SequentialWithIdx(down, SkipConnectionWithIdx(subblock), up) + + self.model = _create_model(1) + + def get_selayer(self, n_filters, se_block_type="None"): + """ + Returns the SEBlock defined in the initialization of the QuickNAT model. + + Args: + n_filters: encoding half of the layer + se_block_type: defaults to None. Valid options are None, CSE, SSE, CSSE + Returns: Appropriate SEBlock. SSE and CSSE not implemented in Monai yet. + """ + if se_block_type == "CSE": + return se.ChannelSELayer(2, n_filters) + # not implemented in squeeze_and_excitation in monai use other squeeze_and_excitation import: + elif se_block_type == "SSE" or se_block_type == "CSSE": + # Throw error if squeeze_and_excitation is not installed + if not flag: + raise ImportError("Please install squeeze_and_excitation locally to use SpatialSELayer") + if se_block_type == "SSE": + return se1.SpatialSELayer(n_filters) + else: + return se1.ChannelSpatialSELayer(n_filters) + else: + return None + + @property + def is_cuda(self): + """ + Check if model parameters are allocated on the GPU. + """ + return next(self.parameters()).is_cuda + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input, _ = self.model(input, None) + return input diff --git a/tests/test_daf3d.py b/tests/test_daf3d.py new file mode 100644 index 0000000000..34e25cc6be --- /dev/null +++ b/tests/test_daf3d.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DAF3D +from monai.utils import optional_import +from tests.utils import test_script_save + +_, has_tv = optional_import("torchvision") + +TEST_CASES = [ + [{"in_channels": 1, "out_channels": 1}, (1, 1, 32, 32, 64), (1, 1, 32, 32, 64)], # single channel 3D, batch 1 + [{"in_channels": 2, "out_channels": 1}, (3, 2, 32, 64, 128), (3, 1, 32, 64, 128)], # two channel 3D, batch 3 + [ + {"in_channels": 2, "out_channels": 2}, + (3, 2, 32, 64, 128), + (3, 2, 32, 64, 128), + ], # two channel 3D, same in & out channels + [{"in_channels": 4, "out_channels": 1}, (5, 4, 35, 35, 35), (5, 1, 35, 35, 35)], # four channel 3D, batch 5 + [ + {"in_channels": 4, "out_channels": 4}, + (5, 4, 35, 35, 35), + (5, 4, 35, 35, 35), + ], # four channel 3D, same in & out channels +] + + +@unittest.skipUnless(has_tv, "torchvision not installed") +class TestDAF3D(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(input_param) + net = DAF3D(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + @unittest.skip("daf3d: torchscript not currently supported") + def test_script(self): + net = DAF3D(in_channels=1, out_channels=1) + test_data = torch.randn(16, 1, 32, 32) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_quicknat.py b/tests/test_quicknat.py new file mode 100644 index 0000000000..b4b89b7d62 --- /dev/null +++ b/tests/test_quicknat.py @@ -0,0 +1,57 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import Quicknat +from monai.utils import optional_import +from tests.utils import test_script_save + +_, has_se = optional_import("squeeze_and_excitation") + +TEST_CASES = [ + # params, input_shape, expected_shape + [{"num_classes": 1, "num_channels": 1, "num_filters": 1, "se_block": None}, (1, 1, 32, 32), (1, 1, 32, 32)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 4, "se_block": None}, (1, 1, 64, 64), (1, 1, 64, 64)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 64, "se_block": None}, (1, 1, 128, 128), (1, 1, 128, 128)], + [{"num_classes": 4, "num_channels": 1, "num_filters": 64, "se_block": None}, (1, 1, 32, 32), (1, 4, 32, 32)], + [{"num_classes": 33, "num_channels": 1, "num_filters": 64, "se_block": None}, (1, 1, 32, 32), (1, 33, 32, 32)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 64, "se_block": "CSE"}, (1, 1, 32, 32), (1, 1, 32, 32)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 64, "se_block": "SSE"}, (1, 1, 32, 32), (1, 1, 32, 32)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 64, "se_block": "CSSE"}, (1, 1, 32, 32), (1, 1, 32, 32)], +] + + +@unittest.skipUnless(has_se, "squeeze_and_excitation not installed") +class TestQuicknat(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(input_param) + net = Quicknat(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_script(self): + net = Quicknat(num_classes=1, num_channels=1) + test_data = torch.randn(16, 1, 32, 32) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main()