diff --git a/docs/source/networks.rst b/docs/source/networks.rst index ed17d815b4..420da311d2 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -119,6 +119,20 @@ Blocks .. autoclass:: Subpixelupsample .. autoclass:: SubpixelUpSample +`LocalNet DownSample Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetDownSampleBlock + :members: + +`LocalNet UpSample Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetUpSampleBlock + :members: + +`LocalNet Feature Extractor Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetFeatureExtractorBlock + :members: Layers @@ -298,6 +312,11 @@ Nets .. autoclass:: VNet :members: +`LocalNet` +~~~~~~~~~~~ +.. autoclass:: LocalNet + :members: + `AutoEncoder` ~~~~~~~~~~~~~ .. autoclass:: AutoEncoder diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 10b13f619c..c33feb4e2b 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -16,6 +16,7 @@ from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine +from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock from .segresnet_block import ResBlock from .squeeze_and_excitation import ( ChannelSELayer, diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py index 5e2dcf163c..7bfb3b47e4 100644 --- a/monai/networks/blocks/convolutions.py +++ b/monai/networks/blocks/convolutions.py @@ -94,6 +94,7 @@ def __init__( padding = same_padding(kernel_size, dilation) conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions] + conv: nn.Module if is_transposed: if output_padding is None: output_padding = stride_minus_kernel_padding(1, strides) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py new file mode 
# ----------------------------------------------------------------------------
# File: monai/networks/blocks/localnet_block.py
# (new file in the patch: mode 100644, index 0000000000..ee7fac0690)
# ----------------------------------------------------------------------------
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from monai.networks.blocks import Convolution
from monai.networks.layers import same_padding
from monai.networks.layers.factories import Norm, Pool


def get_conv_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int] = 3,
    act: Optional[Union[Tuple, str]] = "RELU",
    norm: Optional[Union[Tuple, str]] = "BATCH",
) -> nn.Module:
    """
    Build a bias-free ``Convolution`` block whose padding comes from ``same_padding(kernel_size)``.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size. Defaults to 3.
        act: activation type and arguments, forwarded to ``Convolution``. Defaults to ``"RELU"``.
        norm: normalization type and arguments, forwarded to ``Convolution``. Defaults to ``"BATCH"``.

    Returns:
        The configured ``Convolution`` module (created with ``conv_only=False``).
    """
    padding = same_padding(kernel_size)
    return Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        act=act,
        norm=norm,
        bias=False,
        conv_only=False,
        padding=padding,
    )


def get_conv_layer(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int] = 3,
) -> nn.Module:
    """
    Build a bias-free ``Convolution`` created with ``conv_only=True``.

    The residual blocks in this file apply their own normalization and activation
    to the output of this layer.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size. Defaults to 3.

    Returns:
        The configured ``Convolution`` module.
    """
    padding = same_padding(kernel_size)
    return Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        bias=False,
        conv_only=True,
        padding=padding,
    )


def get_deconv_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
) -> nn.Module:
    """
    Build a bias-free transposed ``Convolution`` block with stride 2, ReLU activation and batch normalization.

    ``LocalNetUpSampleBlock`` adds the output of this block to a feature map that has
    twice the spatial size of the input, so the block is expected to double each spatial dimension.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.

    Returns:
        The configured transposed ``Convolution`` module.
    """
    return Convolution(
        dimensions=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        strides=2,
        act="RELU",
        norm="BATCH",
        bias=False,
        is_transposed=True,
        padding=1,
        output_padding=1,
    )


class ResidualBlock(nn.Module):
    """
    Residual unit computing ``relu(norm(conv(conv_block(x))) + x)`` with a constant channel count.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels, must equal ``in_channels``.
            kernel_size: convolution kernel size.

        Raises:
            ValueError: when ``in_channels != out_channels``.
        """
        super().__init__()
        if in_channels != out_channels:
            raise ValueError(
                f"expecting in_channels == out_channels, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )
        self.conv = get_conv_layer(
            spatial_dims=spatial_dims,
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )
        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input feature map; it is also used as the skip connection.
        """
        out: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x)
        return out


class LocalNetResidualBlock(nn.Module):
    """
    Residual unit computing ``relu(norm(conv_layer(x)) + mid)``, where ``mid`` is supplied by the caller.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels, must equal ``in_channels``.

        Raises:
            ValueError: when ``in_channels != out_channels``.
        """
        super().__init__()
        if in_channels != out_channels:
            raise ValueError(
                f"expecting in_channels == out_channels, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.conv_layer = get_conv_layer(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
        )
        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, mid: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: feature map passed through the convolution and normalization.
            mid: feature map added to the normalized result before the activation.
        """
        out: torch.Tensor = self.relu(self.norm(self.conv_layer(x)) + mid)
        return out


class LocalNetDownSampleBlock(nn.Module):
    """
    A down-sample module that can be used for LocalNet, based on:

    - "Weakly-supervised convolutional neural networks for multimodal image registration"
    - "Label-driven weakly-supervised learning for multimodal deformable image registration"

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.

        Raises:
            NotImplementedError: when ``kernel_size`` is even
                (raised while the padding is computed; exercised by ``test_ill_arg``).
        """
        super().__init__()
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )
        self.residual_block = ResidualBlock(
            spatial_dims=spatial_dims,
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )
        self.max_pool = Pool[Pool.MAX, spatial_dims](kernel_size=2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Halves the spatial dimensions.
        A tuple of (x, mid) is returned:

            -  x is the downsample result, in shape (batch, ``out_channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),
            -  mid is the mid-level feature, in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3])

        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Raises:
            ValueError: when input spatial dimensions are not even.
        """
        for i in x.shape[2:]:
            if i % 2 != 0:
                raise ValueError("expecting x spatial dimensions be even, " f"got x of shape {x.shape}")
        x = self.conv_block(x)
        mid = self.residual_block(x)
        x = self.max_pool(mid)
        return x, mid


class LocalNetUpSampleBlock(nn.Module):
    """
    An up-sample module that can be used for LocalNet, based on:

    - "Weakly-supervised convolutional neural networks for multimodal image registration"
    - "Label-driven weakly-supervised learning for multimodal deformable image registration"

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.

        Raises:
            ValueError: when ``in_channels != 2 * out_channels``
        """
        super().__init__()
        # Validate before building any sub-module. The integer comparison also avoids the
        # ZeroDivisionError that ``in_channels / out_channels`` produced for ``out_channels == 0``.
        if in_channels != 2 * out_channels:
            raise ValueError(
                f"expecting in_channels == 2 * out_channels, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.deconv_block = get_deconv_block(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
        )
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims,
            in_channels=out_channels,
            out_channels=out_channels,
        )
        self.residual_block = LocalNetResidualBlock(
            spatial_dims=spatial_dims,
            in_channels=out_channels,
            out_channels=out_channels,
        )
        self.out_channels = out_channels

    def additive_upsampling(self, x: torch.Tensor, mid: torch.Tensor) -> torch.Tensor:
        """
        Resize ``x`` to the spatial size of ``mid`` and sum its channel groups.

        Args:
            x: feature in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
            mid: feature whose spatial size (``mid.shape[2:]``) is the interpolation target.

        Returns:
            Tensor in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3]).
        """
        upsampled = F.interpolate(x, mid.shape[2:])
        # [(batch, out_channels, ...), (batch, out_channels, ...)]
        chunks = upsampled.split(split_size=int(self.out_channels), dim=1)
        # (batch, out_channels, ...)
        out: torch.Tensor = torch.sum(torch.stack(chunks, dim=-1), dim=-1)
        return out

    def addictive_upsampling(self, x: torch.Tensor, mid: torch.Tensor) -> torch.Tensor:
        """
        Misspelled former name of :py:meth:`additive_upsampling`, kept so existing callers keep working.
        """
        return self.additive_upsampling(x, mid)

    def forward(self, x: torch.Tensor, mid: torch.Tensor) -> torch.Tensor:
        """
        Halves the channel and doubles the spatial dimensions.

        Args:
            x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
            mid: mid-level feature saved during down-sampling,
                in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3])

        Raises:
            ValueError: when ``midsize != insize * 2``
        """
        for i, j in zip(x.shape[2:], mid.shape[2:]):
            if j != 2 * i:
                raise ValueError(
                    "expecting mid spatial dimensions be exactly the double of x spatial dimensions, "
                    f"got x of shape {x.shape}, mid of shape {mid.shape}"
                )
        # learned (transposed convolution) and parameter-free (interpolation) up-sampling are summed
        h0 = self.deconv_block(x) + self.additive_upsampling(x, mid)
        r1 = h0 + mid
        r2 = self.conv_block(h0)
        out: torch.Tensor = self.residual_block(r2, r1)
        return out


class LocalNetFeatureExtractorBlock(nn.Module):
    """
    A feature-extraction module that can be used for LocalNet, based on:

    - "Weakly-supervised convolutional neural networks for multimodal image registration"
    - "Label-driven weakly-supervised learning for multimodal deformable image registration"

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        act: Optional[Union[Tuple, str]] = "RELU",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            act: activation type and arguments. Defaults to ReLU.
        """
        super().__init__()
        # no normalization is used in the extraction head (``norm=None``)
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            act=act,
            norm=None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
        """
        out: torch.Tensor = self.conv_block(x)
        return out


# ----------------------------------------------------------------------------
# File: monai/networks/nets/__init__.py  (index 6c7570ebf9..a9308de9d7)
# The patch adds one export, placed between the ``.highresnet`` and ``.regressor`` imports:
#     from .localnet import LocalNet
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
# File: monai/networks/nets/localnet.py
# (new file in the patch: mode 100644, index 0000000000..1bb3dcbc21)
# ----------------------------------------------------------------------------
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from monai.networks.blocks.localnet_block import (
    LocalNetDownSampleBlock,
    LocalNetFeatureExtractorBlock,
    LocalNetUpSampleBlock,
    get_conv_block,
)


class LocalNet(nn.Module):
    """
    Reimplementation of LocalNet, based on:

    - "Weakly-supervised convolutional neural networks for multimodal image registration"
    - "Label-driven weakly-supervised learning for multimodal deformable image registration"

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_activation: Optional[Union[Tuple, str]],
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            num_channel_initial: number of channels at level 0; level ``l`` uses ``num_channel_initial * 2 ** l``.
            extract_levels: levels from which output features are extracted; the network depth is
                ``max(extract_levels)`` and decoding stops at ``min(extract_levels)``.
            out_activation: activation to use in the extraction layers, ``None`` for no activation.

        Raises:
            ValueError: when ``extract_levels`` is empty.
        """
        super().__init__()
        if len(extract_levels) == 0:
            raise ValueError("extract_levels must contain at least one level")
        self.extract_levels = extract_levels
        self.extract_max_level = max(self.extract_levels)  # E
        self.extract_min_level = min(self.extract_levels)  # D

        num_channels = [
            num_channel_initial * (2 ** level) for level in range(self.extract_max_level + 1)
        ]  # level 0 to E

        self.downsample_blocks = nn.ModuleList(
            [
                LocalNetDownSampleBlock(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels if i == 0 else num_channels[i - 1],
                    out_channels=num_channels[i],
                    kernel_size=7 if i == 0 else 3,
                )
                for i in range(self.extract_max_level)
            ]
        )  # level 0 to self.extract_max_level - 1

        # With a single level (E == 0) there is no down-sampling block, so the bottom
        # convolution receives the raw input; ``num_channels[-2]`` would not exist.
        bottom_in_channels = in_channels if self.extract_max_level == 0 else num_channels[-2]
        # NOTE: used for every ``spatial_dims`` despite its name; the attribute name is kept
        # because sub-module names are part of the ``state_dict`` keys.
        self.conv3d_block = get_conv_block(
            spatial_dims=spatial_dims, in_channels=bottom_in_channels, out_channels=num_channels[-1]
        )  # self.extract_max_level

        self.upsample_blocks = nn.ModuleList(
            [
                LocalNetUpSampleBlock(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[level + 1],
                    out_channels=num_channels[level],
                )
                for level in range(self.extract_max_level - 1, self.extract_min_level - 1, -1)
            ]
        )  # self.extract_max_level - 1 to self.extract_min_level

        self.extract_layers = nn.ModuleList(
            [
                # if kernels are not initialized by zeros, with init NN, extract may be too large
                LocalNetFeatureExtractorBlock(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[level],
                    out_channels=out_channels,
                    act=out_activation,
                )
                for level in self.extract_levels
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Returns:
            Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]): the mean of the
            features extracted at every level in ``extract_levels``, each resized to the input size.

        Raises:
            ValueError: when a spatial dimension is not divisible by ``2 ** max(extract_levels)``.
        """
        image_size = x.shape[2:]
        for size in image_size:
            if size % (2 ** self.extract_max_level) != 0:
                raise ValueError(
                    f"given extract_max_level {self.extract_max_level}, "
                    f"all input spatial dimension must be divisible by {2 ** self.extract_max_level}, "
                    f"got input of size {image_size}"
                )
        mid_features = []  # 0 -> self.extract_max_level - 1
        for downsample_block in self.downsample_blocks:
            x, mid = downsample_block(x)
            mid_features.append(mid)
        x = self.conv3d_block(x)  # self.extract_max_level

        decoded_features = [x]
        for idx, upsample_block in enumerate(self.upsample_blocks):
            # mid-level features are consumed deepest first
            x = upsample_block(x, mid_features[-idx - 1])
            decoded_features.append(x)  # self.extract_max_level -> self.extract_min_level

        # decoded_features[k] holds level (extract_max_level - k)
        output = torch.mean(
            torch.stack(
                [
                    F.interpolate(
                        extract_layer(decoded_features[self.extract_max_level - self.extract_levels[idx]]),
                        size=image_size,
                    )
                    for idx, extract_layer in enumerate(self.extract_layers)
                ],
                dim=-1,
            ),
            dim=-1,
        )
        return output


# ----------------------------------------------------------------------------
# File: tests/test_localnet.py
# (new file in the patch: mode 100644, index 0000000000..d4f812e811)
# ----------------------------------------------------------------------------
import itertools
import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.localnet import LocalNet
from tests.utils import test_script_save

device = "cuda" if torch.cuda.is_available() else "cpu"


TEST_CASE_LOCALNET_2D = [
    [
        {
            "spatial_dims": 2,
            "in_channels": 2,
            "out_channels": 2,
            "num_channel_initial": 16,
            "extract_levels": [0, 1, 2],
            "out_activation": act,
        },
        (1, 2, 16, 16),
        (1, 2, 16, 16),
    ]
    for act in ["sigmoid", None]
]

# a network that only extracts at level 0 has no down-/up-sampling blocks at all
TEST_CASE_LOCALNET_SINGLE_LEVEL = [
    [
        {
            "spatial_dims": 2,
            "in_channels": 2,
            "out_channels": 2,
            "num_channel_initial": 16,
            "extract_levels": [0],
            "out_activation": None,
        },
        (1, 2, 16, 16),
        (1, 2, 16, 16),
    ]
]

# itertools.product varies the last argument fastest, i.e. the same order as nested for-loops
TEST_CASE_LOCALNET_3D = [
    [
        {
            "spatial_dims": 3,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "num_channel_initial": num_channel_initial,
            "extract_levels": extract_levels,
            "out_activation": out_activation,
        },
        (1, in_channels, 16, 16, 16),
        (1, out_channels, 16, 16, 16),
    ]
    for in_channels, out_channels, num_channel_initial, extract_levels, out_activation in itertools.product(
        [2, 3],
        [1, 3],
        [4, 16, 32],
        [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]],
        ["sigmoid", None],
    )
]


class TestLocalNet(unittest.TestCase):
    @parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_SINGLE_LEVEL + TEST_CASE_LOCALNET_3D)
    def test_shape(self, input_param, input_shape, expected_shape):
        net = LocalNet(**input_param).to(device)
        with eval_mode(net):
            result = net(torch.randn(input_shape).to(device))
            self.assertEqual(result.shape, expected_shape)

    def test_ill_shape(self):
        input_param, _, _ = TEST_CASE_LOCALNET_2D[0]
        input_shape = (1, input_param["in_channels"], 17, 17)
        # build the network outside the context so only the forward pass can satisfy the assertion
        net = LocalNet(**input_param).to(device)
        with self.assertRaises(ValueError):
            net(torch.randn(input_shape).to(device))

    def test_ill_arg(self):
        # empty extract_levels
        input_param = dict(TEST_CASE_LOCALNET_2D[0][0], extract_levels=[])
        with self.assertRaises(ValueError):
            LocalNet(**input_param)

    def test_script(self):
        input_param, input_shape, _ = TEST_CASE_LOCALNET_2D[0]
        net = LocalNet(**input_param)
        test_data = torch.randn(input_shape)
        test_script_save(net, test_data)


if __name__ == "__main__":
    unittest.main()


# ----------------------------------------------------------------------------
# File: tests/test_localnet_block.py
# (new file in the patch: mode 100644, index 0000000000..af5ef19222)
# ----------------------------------------------------------------------------
import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.localnet_block import (
    LocalNetDownSampleBlock,
    LocalNetFeatureExtractorBlock,
    LocalNetUpSampleBlock,
)

TEST_CASE_DOWN_SAMPLE = [
    [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3]
]

TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]]

TEST_CASE_EXTRACT = [
    [
        {
            "spatial_dims": spatial_dims,
            "in_channels": 2,
            "out_channels": 3,
            "act": act,
        }
    ]
    for spatial_dims, act in zip([2, 3], ["sigmoid", None])
]

in_size = 4


class TestLocalNetDownSampleBlock(unittest.TestCase):
    @parameterized.expand(TEST_CASE_DOWN_SAMPLE)
    def test_shape(self, input_param):
        net = LocalNetDownSampleBlock(**input_param)
        input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"]))
        expect_mid_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"]))
        # integer division: tensor shapes are made of ints, not floats
        expect_x_shape = (1, input_param["out_channels"], *([in_size // 2] * input_param["spatial_dims"]))
        with eval_mode(net):
            x, mid = net(torch.randn(input_shape))
            self.assertEqual(x.shape, expect_x_shape)
            self.assertEqual(mid.shape, expect_mid_shape)

    def test_ill_arg(self):
        # even kernel_size
        with self.assertRaises(NotImplementedError):
            LocalNetDownSampleBlock(spatial_dims=2, in_channels=2, out_channels=4, kernel_size=4)

    @parameterized.expand(TEST_CASE_DOWN_SAMPLE)
    def test_ill_shape(self, input_param):
        net = LocalNetDownSampleBlock(**input_param)
        input_shape = (1, input_param["in_channels"], *([5] * input_param["spatial_dims"]))
        with self.assertRaises(ValueError):
            with eval_mode(net):
                net(torch.randn(input_shape))


class TestLocalNetUpSampleBlock(unittest.TestCase):
    @parameterized.expand(TEST_CASE_UP_SAMPLE)
    def test_shape(self, input_param):
        net = LocalNetUpSampleBlock(**input_param)
        input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"]))
        mid_shape = (1, input_param["out_channels"], *([in_size * 2] * input_param["spatial_dims"]))
        expected_shape = mid_shape
        with eval_mode(net):
            result = net(torch.randn(input_shape), torch.randn(mid_shape))
            self.assertEqual(result.shape, expected_shape)

    @parameterized.expand(TEST_CASE_UP_SAMPLE)
    def test_additive_upsampling_alias(self, input_param):
        net = LocalNetUpSampleBlock(**input_param)
        x = torch.randn((1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])))
        mid = torch.randn((1, input_param["out_channels"], *([in_size * 2] * input_param["spatial_dims"])))
        with eval_mode(net):
            expected = net.additive_upsampling(x, mid)
            # the misspelled legacy name must keep returning the same result
            self.assertTrue(torch.equal(net.addictive_upsampling(x, mid), expected))
            self.assertEqual(expected.shape, mid.shape)

    def test_ill_arg(self):
        # channel unmatch
        with self.assertRaises(ValueError):
            LocalNetUpSampleBlock(spatial_dims=2, in_channels=2, out_channels=2)
        # zero output channels must be reported as a channel mismatch as well
        with self.assertRaises(ValueError):
            LocalNetUpSampleBlock(spatial_dims=2, in_channels=2, out_channels=0)

    @parameterized.expand(TEST_CASE_UP_SAMPLE)
    def test_ill_shape(self, input_param):
        net = LocalNetUpSampleBlock(**input_param)
        input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"]))
        mid_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"]))
        with self.assertRaises(ValueError):
            with eval_mode(net):
                net(torch.randn(input_shape), torch.randn(mid_shape))


class TestExtractBlock(unittest.TestCase):
    @parameterized.expand(TEST_CASE_EXTRACT)
    def test_shape(self, input_param):
        net = LocalNetFeatureExtractorBlock(**input_param)
        input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"]))
        expected_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"]))
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)


if __name__ == "__main__":
    unittest.main()