Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1442 add LocalNet #1447

Merged
merged 37 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
7579659
Merge remote-tracking branch 'Project-MONAI/master'
kate-sann5100 Jan 6, 2021
ce61c38
1412 add local normalized cross correlation
kate-sann5100 Jan 7, 2021
5cf91d0
1412 add unit test and documentation
kate-sann5100 Jan 7, 2021
9376195
1412 fix bug
kate-sann5100 Jan 7, 2021
ac36a9f
1412 reformat code
kate-sann5100 Jan 7, 2021
ed6c28b
1412 debug type check
kate-sann5100 Jan 7, 2021
43c2f35
1412 use separable filter for speed
kate-sann5100 Jan 9, 2021
9cc4438
Merge branch 'master' into 1412-local-normalized-cross-correlation
kate-sann5100 Jan 9, 2021
f76e3f0
1412 update Union import route
kate-sann5100 Jan 9, 2021
50940c2
1412 fix negative bug and add smooth_nr
kate-sann5100 Jan 9, 2021
cab9b0b
remove temp. code
wyli Jan 10, 2021
6db2528
1412 reformat code
kate-sann5100 Jan 10, 2021
a3dd503
Merge remote-tracking branch 'origin/1412-local-normalized-cross-corr…
kate-sann5100 Jan 10, 2021
af4cab5
1412 remove redundant import
kate-sann5100 Jan 10, 2021
c3b21c2
Merge remote-tracking branch 'Project-MONAI/master'
kate-sann5100 Jan 10, 2021
46124da
Merge branch '1412-local-normalized-cross-correlation'
kate-sann5100 Jan 10, 2021
a74e61c
Merge remote-tracking branch 'Project-MONAI/master'
kate-sann5100 Jan 12, 2021
0396f20
Merge remote-tracking branch 'Project-MONAI/master' into 1442-localnet
kate-sann5100 Jan 13, 2021
8cc4f88
1442 add localnet
kate-sann5100 Jan 13, 2021
3455ee8
1442 add test
kate-sann5100 Jan 13, 2021
16011d7
1442 add documentation
kate-sann5100 Jan 13, 2021
a99782b
1442 add typing
kate-sann5100 Jan 13, 2021
ca5baa2
Merge remote-tracking branch 'Project-MONAI/master' into 1442-localnet
kate-sann5100 Jan 13, 2021
3772414
1442 reformat
kate-sann5100 Jan 13, 2021
c2373e9
1442 reformat
kate-sann5100 Jan 13, 2021
99b3fac
1442 reformat
kate-sann5100 Jan 13, 2021
9437fb3
1442 reformat
kate-sann5100 Jan 13, 2021
7ef82c5
1442 reformat
kate-sann5100 Jan 13, 2021
c510806
1442 reformat
kate-sann5100 Jan 14, 2021
61691e5
1442 reformat
kate-sann5100 Jan 14, 2021
1399f68
1442 reformat
kate-sann5100 Jan 14, 2021
6329035
1442 reformat
kate-sann5100 Jan 14, 2021
6e8b60e
1442 reformat
kate-sann5100 Jan 14, 2021
c51d597
1442 reformat
kate-sann5100 Jan 14, 2021
2f18619
1442 remove initialization
kate-sann5100 Jan 14, 2021
7bc6e51
Merge branch 'master' into 1442-localnet
wyli Jan 14, 2021
786eb07
1442 update factory calls
kate-sann5100 Jan 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ Blocks
.. autoclass:: Subpixelupsample
.. autoclass:: SubpixelUpSample

`LocalNet DownSample Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetDownSampleBlock
:members:

`LocalNet UpSample Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetUpSampleBlock
:members:

`LocalNet Feature Extractor Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetFeatureExtractorBlock
:members:


Layers
Expand Down Expand Up @@ -298,6 +312,11 @@ Nets
.. autoclass:: VNet
:members:

`LocalNet`
~~~~~~~~~~~
.. autoclass:: LocalNet
:members:

`AutoEncoder`
~~~~~~~~~~~~~
.. autoclass:: AutoEncoder
Expand Down
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .downsample import MaxAvgPool
from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
from .fcn import FCN, GCN, MCFCN, Refine
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
from .segresnet_block import ResBlock
from .squeeze_and_excitation import (
ChannelSELayer,
Expand Down
1 change: 1 addition & 0 deletions monai/networks/blocks/convolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
padding = same_padding(kernel_size, dilation)
conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions]

conv: nn.Module
if is_transposed:
if output_padding is None:
output_padding = stride_minus_kernel_padding(1, strides)
Expand Down
308 changes: 308 additions & 0 deletions monai/networks/blocks/localnet_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from monai.networks.blocks import Convolution
from monai.networks.layers import same_padding
from monai.networks.layers.factories import Norm, Pool


def get_conv_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int] = 3,
    act: Optional[Union[Tuple, str]] = "RELU",
    norm: Optional[Union[Tuple, str]] = "BATCH",
) -> nn.Module:
    """
    Build a convolution block (convolution + normalization + activation) with "same" padding.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size. Defaults to 3.
        act: activation type and arguments. Defaults to "RELU".
        norm: normalization type and arguments. Defaults to "BATCH".
    """
    pad = same_padding(kernel_size)
    block = Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        act=act,
        norm=norm,
        bias=False,
        conv_only=False,
        padding=pad,
    )
    return block


def get_conv_layer(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int] = 3,
) -> nn.Module:
    """
    Build a bare convolution layer (no normalization or activation) with "same" padding.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size. Defaults to 3.
    """
    pad = same_padding(kernel_size)
    layer = Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        bias=False,
        conv_only=True,
        padding=pad,
    )
    return layer


def get_deconv_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
) -> nn.Module:
    """
    Build a transposed-convolution block (stride 2, batch norm + ReLU) that doubles
    each spatial dimension (padding=1 with output_padding=1).

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
    """
    deconv = Convolution(
        dimensions=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        strides=2,
        padding=1,
        output_padding=1,
        act="RELU",
        norm="BATCH",
        bias=False,
        is_transposed=True,
    )
    return deconv


class ResidualBlock(nn.Module):
    """
    Residual unit used by the LocalNet down-sample path: the input is passed through
    a conv block, a conv layer and a norm, then summed with the identity and passed
    through ReLU. Requires ``in_channels == out_channels`` so the skip connection
    can be added element-wise.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels; must equal ``in_channels``.
            kernel_size: convolution kernel size.

        Raises:
            ValueError: when ``in_channels != out_channels``.
        """
        super().__init__()
        if in_channels != out_channels:
            raise ValueError(
                f"expecting in_channels == out_channels, got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )
        self.conv = get_conv_layer(
            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size
        )
        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
        self.relu = nn.ReLU()

    def forward(self, x) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3]).
        """
        # residual sum of the two-conv branch and the identity, then ReLU
        out: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x)
        return out


class LocalNetResidualBlock(nn.Module):
    """
    Residual unit used by the LocalNet up-sample path: a conv layer and norm are
    applied to ``x`` and summed with an externally supplied skip tensor ``mid``
    before ReLU. Requires ``in_channels == out_channels`` so the sum is well-formed.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels; must equal ``in_channels``.

        Raises:
            ValueError: when ``in_channels != out_channels``.
        """
        super().__init__()
        if in_channels != out_channels:
            raise ValueError(
                f"expecting in_channels == out_channels, got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.conv_layer = get_conv_layer(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
        )
        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, mid) -> torch.Tensor:
        """
        Args:
            x: Tensor to be convolved, shape (batch, ``in_channels``, ...).
            mid: skip tensor added after normalization; must match the conv output shape.
        """
        out: torch.Tensor = self.relu(self.norm(self.conv_layer(x)) + mid)
        return out


class LocalNetDownSampleBlock(nn.Module):
    """
    A down-sample module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
        Raises:
            NotImplementedError: when ``kernel_size`` is even
                (raised by ``same_padding`` inside the conv blocks — TODO confirm).
        """
        super().__init__()
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
        )
        self.residual_block = ResidualBlock(
            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size
        )
        # kernel_size=2 max-pool (stride defaults to kernel size) halves each spatial dim
        self.max_pool = Pool[Pool.MAX, spatial_dims](
            kernel_size=2,
        )

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Halves the spatial dimensions.
        A tuple of (x, mid) is returned:

        - x is the downsample result, in shape (batch, ``out_channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),
        - mid is the mid-level feature, in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3])

        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Raises:
            ValueError: when input spatial dimensions are not even.
        """
        # a kernel-2 max-pool would silently drop the last row/column of an odd
        # dimension, so odd spatial sizes are rejected explicitly
        for i in x.shape[2:]:
            if i % 2 != 0:
                raise ValueError("expecting x spatial dimensions be even, " f"got x of shape {x.shape}")
        x = self.conv_block(x)
        mid = self.residual_block(x)
        x = self.max_pool(mid)
        return x, mid


class LocalNetUpSampleBlock(nn.Module):
    """
    An up-sample module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
        Raises:
            ValueError: when ``in_channels != 2 * out_channels``
        """
        super().__init__()
        # validate before building submodules so a bad configuration fails fast;
        # integer comparison avoids the original float-division check
        if in_channels != 2 * out_channels:
            raise ValueError(
                f"expecting in_channels == 2 * out_channels, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.deconv_block = get_deconv_block(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
        )
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims,
            in_channels=out_channels,
            out_channels=out_channels,
        )
        self.residual_block = LocalNetResidualBlock(
            spatial_dims=spatial_dims,
            in_channels=out_channels,
            out_channels=out_channels,
        )
        self.out_channels = out_channels

    # NOTE(review): method name keeps the original (misspelled) public API —
    # "additive" was intended; renaming would break existing callers.
    def addictive_upsampling(self, x, mid) -> torch.Tensor:
        """
        Resize ``x`` to ``mid``'s spatial size and sum its two channel halves,
        producing an additive up-sampling with ``out_channels`` channels.
        """
        x = F.interpolate(x, mid.shape[2:])
        # [(batch, out_channels, ...), (batch, out_channels, ...)]
        x = x.split(split_size=int(self.out_channels), dim=1)
        # (batch, out_channels, ...)
        out: torch.Tensor = torch.sum(torch.stack(x, dim=-1), dim=-1)
        return out

    def forward(self, x, mid) -> torch.Tensor:
        """
        Halves the channel and doubles the spatial dimensions.

        Args:
            x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
            mid: mid-level feature saved during down-sampling,
                in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3])

        Raises:
            ValueError: when ``midsize != insize * 2``
        """
        for i, j in zip(x.shape[2:], mid.shape[2:]):
            if j != 2 * i:
                raise ValueError(
                    "expecting mid spatial dimensions be exactly the double of x spatial dimensions, "
                    f"got x of shape {x.shape}, mid of shape {mid.shape}"
                )
        # transposed-conv up-sampling plus additive up-sampling of the input
        h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid)
        r1 = h0 + mid
        r2 = self.conv_block(h0)
        out: torch.Tensor = self.residual_block(r2, r1)
        return out


class LocalNetFeatureExtractorBlock(nn.Module):
    """
    A feature-extraction module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        act: Optional[Union[Tuple, str]] = "RELU",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            act: activation type and arguments. Defaults to ReLU.
        """
        super().__init__()
        # a single conv block with activation but no normalization (norm=None)
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None
        )

    def forward(self, x) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
        """
        out: torch.Tensor = self.conv_block(x)
        return out
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .localnet import LocalNet
from .regressor import Regressor
from .segresnet import SegResNet, SegResNetVAE
from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154
Expand Down
Loading