1442 add LocalNet (#1447)
* 1442 add localnet

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 add test

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 add documentation

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 add typing

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 reformat

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 reformat

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 reformat

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 remove initializsation

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 update factory calls

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>
kate-sann5100 authored Jan 15, 2021
1 parent c0983c4 commit e115d6c
Showing 8 changed files with 637 additions and 0 deletions.
19 changes: 19 additions & 0 deletions docs/source/networks.rst
@@ -119,6 +119,20 @@ Blocks
.. autoclass:: Subpixelupsample
.. autoclass:: SubpixelUpSample

`LocalNet DownSample Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetDownSampleBlock
:members:

`LocalNet UpSample Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetUpSampleBlock
:members:

`LocalNet Feature Extractor Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetFeatureExtractorBlock
:members:


Layers
@@ -298,6 +312,11 @@ Nets
.. autoclass:: VNet
:members:

`LocalNet`
~~~~~~~~~~~
.. autoclass:: LocalNet
:members:

`AutoEncoder`
~~~~~~~~~~~~~
.. autoclass:: AutoEncoder
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
@@ -16,6 +16,7 @@
from .downsample import MaxAvgPool
from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
from .fcn import FCN, GCN, MCFCN, Refine
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
from .segresnet_block import ResBlock
from .squeeze_and_excitation import (
ChannelSELayer,
1 change: 1 addition & 0 deletions monai/networks/blocks/convolutions.py
@@ -94,6 +94,7 @@ def __init__(
padding = same_padding(kernel_size, dilation)
conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions]

conv: nn.Module
if is_transposed:
if output_padding is None:
output_padding = stride_minus_kernel_padding(1, strides)
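The only change to this file is the ``conv: nn.Module`` annotation, which covers both the transposed and the plain branch that follow. As a hedged sketch, the transposed branch can be exercised through ``get_deconv_block`` from the ``localnet_block`` module added later in this commit; the shapes below are illustrative assumptions, not part of the change.

import torch

from monai.networks.blocks.localnet_block import get_deconv_block

# Minimal sketch (not part of this commit): get_deconv_block builds a
# transposed Convolution with strides=2, padding=1, output_padding=1,
# so each spatial dimension is doubled.
deconv = get_deconv_block(spatial_dims=3, in_channels=8, out_channels=4)
out = deconv(torch.rand(1, 8, 8, 8, 8))
print(out.shape)  # expected: torch.Size([1, 4, 16, 16, 16])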
308 changes: 308 additions & 0 deletions monai/networks/blocks/localnet_block.py
@@ -0,0 +1,308 @@
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from monai.networks.blocks import Convolution
from monai.networks.layers import same_padding
from monai.networks.layers.factories import Norm, Pool


def get_conv_block(
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int] = 3,
act: Optional[Union[Tuple, str]] = "RELU",
norm: Optional[Union[Tuple, str]] = "BATCH",
) -> nn.Module:
padding = same_padding(kernel_size)
return Convolution(
spatial_dims,
in_channels,
out_channels,
kernel_size=kernel_size,
act=act,
norm=norm,
bias=False,
conv_only=False,
padding=padding,
)


def get_conv_layer(
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int] = 3,
) -> nn.Module:
padding = same_padding(kernel_size)
return Convolution(
spatial_dims,
in_channels,
out_channels,
kernel_size=kernel_size,
bias=False,
conv_only=True,
padding=padding,
)


def get_deconv_block(
spatial_dims: int,
in_channels: int,
out_channels: int,
) -> nn.Module:
return Convolution(
dimensions=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
strides=2,
act="RELU",
norm="BATCH",
bias=False,
is_transposed=True,
padding=1,
output_padding=1,
)


class ResidualBlock(nn.Module):
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int],
) -> None:
super(ResidualBlock, self).__init__()
if in_channels != out_channels:
raise ValueError(
f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}"
)
self.conv_block = get_conv_block(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
)
self.conv = get_conv_layer(
spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size
)
self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
self.relu = nn.ReLU()

def forward(self, x) -> torch.Tensor:
out: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x)
return out


class LocalNetResidualBlock(nn.Module):
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
) -> None:
super(LocalNetResidualBlock, self).__init__()
if in_channels != out_channels:
raise ValueError(
f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}"
)
self.conv_layer = get_conv_layer(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
)
self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
self.relu = nn.ReLU()

def forward(self, x, mid) -> torch.Tensor:
out: torch.Tensor = self.relu(self.norm(self.conv_layer(x)) + mid)
return out


class LocalNetDownSampleBlock(nn.Module):
"""
A down-sample module that can be used for LocalNet, based on:
`Weakly-supervised convolutional neural networks for multimodal image registration
<https://doi.org/10.1016/j.media.2018.07.002>`_.
`Label-driven weakly-supervised learning for multimodal deformable image registration
<https://arxiv.org/abs/1711.01666>`_.
Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int],
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
kernel_size: convolution kernel size.
Raises:
NotImplementedError: when ``kernel_size`` is even
"""
super(LocalNetDownSampleBlock, self).__init__()
self.conv_block = get_conv_block(
spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
)
self.residual_block = ResidualBlock(
spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size
)
self.max_pool = Pool[Pool.MAX, spatial_dims](
kernel_size=2,
)

def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Halves the spatial dimensions.
A tuple of (x, mid) is returned:
- x is the downsample result, in shape (batch, ``out_channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),
- mid is the mid-level feature, in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3])
Args:
x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
Raises:
ValueError: when input spatial dimensions are not even.
"""
for i in x.shape[2:]:
if i % 2 != 0:
raise ValueError("expecting x spatial dimensions be even, " f"got x of shape {x.shape}")
x = self.conv_block(x)
mid = self.residual_block(x)
x = self.max_pool(mid)
return x, mid


class LocalNetUpSampleBlock(nn.Module):
"""
An up-sample module that can be used for LocalNet, based on:
`Weakly-supervised convolutional neural networks for multimodal image registration
<https://doi.org/10.1016/j.media.2018.07.002>`_.
`Label-driven weakly-supervised learning for multimodal deformable image registration
<https://arxiv.org/abs/1711.01666>`_.
Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
Raises:
ValueError: when ``in_channels != 2 * out_channels``
"""
super(LocalNetUpSampleBlock, self).__init__()
self.deconv_block = get_deconv_block(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
)
self.conv_block = get_conv_block(
spatial_dims=spatial_dims,
in_channels=out_channels,
out_channels=out_channels,
)
self.residual_block = LocalNetResidualBlock(
spatial_dims=spatial_dims,
in_channels=out_channels,
out_channels=out_channels,
)
if in_channels / out_channels != 2:
raise ValueError(
f"expecting in_channels == 2 * out_channels, "
f"got in_channels={in_channels}, out_channels={out_channels}"
)
self.out_channels = out_channels

def additive_upsampling(self, x, mid) -> torch.Tensor:
x = F.interpolate(x, mid.shape[2:])
# [(batch, out_channels, ...), (batch, out_channels, ...)]
x = x.split(split_size=int(self.out_channels), dim=1)
# (batch, out_channels, ...)
out: torch.Tensor = torch.sum(torch.stack(x, dim=-1), dim=-1)
return out

def forward(self, x, mid) -> torch.Tensor:
"""
Halves the number of channels and doubles the spatial dimensions.
Args:
x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
mid: mid-level feature saved during down-sampling,
in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3])
Raises:
ValueError: when ``midsize != insize * 2``
"""
for i, j in zip(x.shape[2:], mid.shape[2:]):
if j != 2 * i:
raise ValueError(
"expecting mid spatial dimensions be exactly the double of x spatial dimensions, "
f"got x of shape {x.shape}, mid of shape {mid.shape}"
)
h0 = self.deconv_block(x) + self.additive_upsampling(x, mid)
r1 = h0 + mid
r2 = self.conv_block(h0)
out: torch.Tensor = self.residual_block(r2, r1)
return out


class LocalNetFeatureExtractorBlock(nn.Module):
"""
A feature-extraction module that can be used for LocalNet, based on:
`Weakly-supervised convolutional neural networks for multimodal image registration
<https://doi.org/10.1016/j.media.2018.07.002>`_.
`Label-driven weakly-supervised learning for multimodal deformable image registration
<https://arxiv.org/abs/1711.01666>`_.
Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
act: Optional[Union[Tuple, str]] = "RELU",
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
act: activation type and arguments. Defaults to ReLU.
"""
super(LocalNetFeatureExtractorBlock, self).__init__()
self.conv_block = get_conv_block(
spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None
)

def forward(self, x) -> torch.Tensor:
"""
Args:
x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
"""
out: torch.Tensor = self.conv_block(x)
return out
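
A minimal usage sketch of the three public blocks added in this file, relying only on the constructors and forward signatures shown above; the tensor shapes and the stand-in deeper feature are illustrative assumptions, not part of the commit.

import torch

from monai.networks.blocks import (
    LocalNetDownSampleBlock,
    LocalNetFeatureExtractorBlock,
    LocalNetUpSampleBlock,
)

# Down-sample: returns the pooled feature and the pre-pool "mid" feature.
down = LocalNetDownSampleBlock(spatial_dims=3, in_channels=2, out_channels=16, kernel_size=3)
x = torch.rand(1, 2, 32, 32, 32)
pooled, mid = down(x)  # pooled: (1, 16, 16, 16, 16), mid: (1, 16, 32, 32, 32)

# Up-sample: in_channels must equal 2 * out_channels, and mid's spatial size
# must be twice that of the input feature.
up = LocalNetUpSampleBlock(spatial_dims=3, in_channels=32, out_channels=16)
deep = torch.rand(1, 32, 16, 16, 16)  # stand-in for a deeper encoder feature
decoded = up(deep, mid)  # (1, 16, 32, 32, 32)

# Feature extractor: a convolution block with activation but no normalization.
extract = LocalNetFeatureExtractorBlock(spatial_dims=3, in_channels=16, out_channels=3)
ddf = extract(decoded)  # (1, 3, 32, 32, 32)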
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
@@ -18,6 +18,7 @@
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .localnet import LocalNet
from .regressor import Regressor
from .segresnet import SegResNet, SegResNetVAE
from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154