From ce61c38c377059f27deef07b26433669ab45c850 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 00:00:21 +0000 Subject: [PATCH 01/28] 1412 add local normalized cross correlation Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 157 ++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 monai/losses/image_dissimilarity.py diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py new file mode 100644 index 0000000000..557a757e79 --- /dev/null +++ b/monai/losses/image_dissimilarity.py @@ -0,0 +1,157 @@ +import torch +from torch.nn.modules.loss import _Loss + +from torch.nn import functional as F + +from monai.utils import LossReduction, Union + + +conv_dict = { + 1: F.conv1d, + 2: F.conv2d, + 3: F.conv3d +} + +EPS = 1e-7 + + +class LocalNormalizedCrossCorrelation(_Loss): + """ + Local squared zero-normalized cross-correlation. + The loss is based on a moving kernel/window over the y_true/y_pred, + within the window the square of zncc is calculated. + The kernel can be a rectangular / triangular / gaussian window. + The final loss is the averaged loss over all windows. + + Adapted from: + https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + + def __init__( + self, + in_channels: int, + dim: int = 3, + kernel_size: int = 9, + kernel_type: str = "rectangular", + reduction: Union[LossReduction, str] = LossReduction.MEAN, + ) -> None: + """ + Args: + in_channels: number of input channels + dim: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. + kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` + kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) + self.in_channels = in_channels + self.dim = dim + if self.dim not in [1, 2, 3]: + raise ValueError(f'Unsupported dim: {self.dim}-d, only 1-d, 2-d, and 3-d inputs are supported') + self.kernel_size = kernel_size + if kernel_type == "rectangular": + self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() + elif kernel_type == "triangular": + self.kernel, self.kernel_vol, self.padding = self.make_triangular_kernel() + elif kernel_type == "gaussian": + self.kernel, self.kernel_vol, self.padding = self.make_gaussian_kernel() + else: + raise ValueError( + f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' + ) + + def make_rectangular_kernel(self): + shape = [1, self.in_channels] + [self.kernel_size] * self.dim + return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.dim, int((self.kernel_size - 1) / 2) + + def make_triangular_kernel(self): + fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) + f1 = torch.ones( + [1, 1] + [fsize] * self.dim, + dtype=torch.float + ) / fsize # (1, 1, D, H, W) + f2 = torch.ones( + [1, self.in_channels] + [fsize] * self.dim, + dtype=torch.float + ) / fsize # (1, in_channels, D, H, W) + # (1, 1, D, H, W) -> (1, in_channels, D, H, W) + fn = conv_dict[self.dim] + kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) + + return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) + + def make_gaussian_kernel(self): + mean = (self.kernel_size - 1) / 2.0 + sigma = self.kernel_size / 3 + + grid_dim = torch.arange(0, self.kernel_size) + grid_dim_ch = torch.arange(0, self.in_channel) + + if self.dim == 1: + grid = torch.meshgrid(grid_dim_ch, grid_dim) + elif self.dim == 2: + grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim) + elif self.dim == 3: + grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim, grid_dim) + else: + raise ValueError + + grid = torch.stack(grid, dim=-1).to(dtype=torch.float) + kernel = torch.exp( + -torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2) + ).unsqueeze(0) # (1, in_channel, kernel_size, kernel_size, kernel_size) + return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD]. + target: the shape should be BNH[WD]. + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + """ + assert ( + target.shape == input.shape + ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" + + t2, p2, tp = target ** 2, input ** 2, target * input + + # sum over kernel + fn = conv_dict[self.dim] + t_sum = fn(target, weight=self.kernel, padding=self.padding) + p_sum = fn(input, weight=self.kernel, padding=self.paddin) + t2_sum = fn(t2, weight=self.kernel, padding=self.paddin) + p2_sum = fn(p2, weight=self.kernel, padding=self.paddin) + tp_sum = fn(tp, weight=self.kernel, padding=self.paddin) + + # average over kernel + t_avg = t_sum / self.kernel_vol + p_avg = p_sum / self.kernel_vol + + # normalized cross correlation between t and p + # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p] + # denoted by num / denom + # assume we sum over N values + # num = sum[t * p - mean[t] * p - t * mean[p] + mean[t] * mean[p]] + # = sum[t*p] - sum[t] * sum[p] / N * 2 + sum[t] * sum[p] / N + # = sum[t*p] - sum[t] * sum[p] / N + # = sum[t*p] - sum[t] * mean[p] = cross + # the following is actually squared ncc + cross = tp_sum - p_avg * t_sum + t_var = t2_sum - t_avg * t_sum # std[t] ** 2 + p_var = p2_sum - p_avg * p_sum # std[p] ** 2 + ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) + + if self.reduction == LossReduction.SUM.value: + return torch.sum(ncc) # sum over the batch and channel dims + if self.reduction == LossReduction.NONE.value: + return ncc + if self.reduction == LossReduction.MEAN.value: + return torch.mean(ncc) # average over the batch and channel dims + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') From 5cf91d0463f3b52d36a193dfea290263d9ad0819 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 01:02:13 +0000 Subject: [PATCH 02/28] 1412 add unit test and documentation Signed-off-by: kate-sann5100 --- docs/source/losses.rst | 5 + monai/losses/__init__.py | 1 + monai/losses/image_dissimilarity.py | 94 +++++++------- ...local_normalized_cross_correlation_loss.py | 121 ++++++++++++++++++ 4 files changed, 172 insertions(+), 49 deletions(-) create mode 100644 tests/test_local_normalized_cross_correlation_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index d2c8e02ca4..fc86c9cd05 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -65,3 +65,8 @@ Registration Losses ~~~~~~~~~~~~~~~~~~~ .. autoclass:: BendingEnergyLoss :members: + +`LocalNormalizedCrossCorrelationLoss` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNormalizedCrossCorrelationLoss + :members: diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d4e21f900c..42ccd0f65a 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -22,4 +22,5 @@ generalized_wasserstein_dice, ) from .focal_loss import FocalLoss +from .image_dissimilarity import LocalNormalizedCrossCorrelationLoss from .tversky import TverskyLoss diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 557a757e79..26ee7ea718 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -1,21 +1,15 @@ import torch -from torch.nn.modules.loss import _Loss - from torch.nn import functional as F +from torch.nn.modules.loss import _Loss from monai.utils import LossReduction, Union - -conv_dict = { - 1: F.conv1d, - 2: F.conv2d, - 3: F.conv3d -} +conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} EPS = 1e-7 -class LocalNormalizedCrossCorrelation(_Loss): +class LocalNormalizedCrossCorrelationLoss(_Loss): """ Local squared zero-normalized cross-correlation. The loss is based on a moving kernel/window over the y_true/y_pred, @@ -29,17 +23,17 @@ class LocalNormalizedCrossCorrelation(_Loss): """ def __init__( - self, - in_channels: int, - dim: int = 3, - kernel_size: int = 9, - kernel_type: str = "rectangular", - reduction: Union[LossReduction, str] = LossReduction.MEAN, + self, + in_channels: int, + ndim: int = 3, + kernel_size: int = 9, + kernel_type: str = "rectangular", + reduction: Union[LossReduction, str] = LossReduction.MEAN, ) -> None: """ Args: in_channels: number of input channels - dim: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. + ndim: number of spatial ndimensions, {``1``, ``2``, ``3``}. Defaults to 3. kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -51,9 +45,9 @@ def __init__( """ super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels - self.dim = dim - if self.dim not in [1, 2, 3]: - raise ValueError(f'Unsupported dim: {self.dim}-d, only 1-d, 2-d, and 3-d inputs are supported') + self.ndim = ndim + if self.ndim not in [1, 2, 3]: + raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") self.kernel_size = kernel_size if kernel_type == "rectangular": self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() @@ -67,21 +61,17 @@ def __init__( ) def make_rectangular_kernel(self): - shape = [1, self.in_channels] + [self.kernel_size] * self.dim - return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.dim, int((self.kernel_size - 1) / 2) + shape = [1, self.in_channels] + [self.kernel_size] * self.ndim + return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) def make_triangular_kernel(self): fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) - f1 = torch.ones( - [1, 1] + [fsize] * self.dim, - dtype=torch.float - ) / fsize # (1, 1, D, H, W) - f2 = torch.ones( - [1, self.in_channels] + [fsize] * self.dim, - dtype=torch.float - ) / fsize # (1, in_channels, D, H, W) + f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize # (1, 1, D, H, W) + f2 = ( + torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize + ) # (1, in_channels, D, H, W) # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - fn = conv_dict[self.dim] + fn = conv_dict[self.ndim] kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) @@ -90,22 +80,22 @@ def make_gaussian_kernel(self): mean = (self.kernel_size - 1) / 2.0 sigma = self.kernel_size / 3 - grid_dim = torch.arange(0, self.kernel_size) - grid_dim_ch = torch.arange(0, self.in_channel) + grid_ndim = torch.arange(0, self.kernel_size) + grid_ndim_ch = torch.arange(0, self.in_channels) - if self.dim == 1: - grid = torch.meshgrid(grid_dim_ch, grid_dim) - elif self.dim == 2: - grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim) - elif self.dim == 3: - grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim, grid_dim) + if self.ndim == 1: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim) + elif self.ndim == 2: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim) + elif self.ndim == 3: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim, grid_ndim) else: raise ValueError grid = torch.stack(grid, dim=-1).to(dtype=torch.float) - kernel = torch.exp( - -torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2) - ).unsqueeze(0) # (1, in_channel, kernel_size, kernel_size, kernel_size) + kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( + 0 + ) # (1, in_channel, kernel_size, kernel_size, kernel_size) return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: @@ -116,6 +106,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ + assert ( + input.shape[1] == self.in_channels + ), f"expecting input with {self.in_channels} channels, got input of shape {input.shape}" + assert ( + input.ndim - 2 == self.ndim + ), f"expecting input with {self.ndim} spatial dimensions, got input of shape {input.shape}" assert ( target.shape == input.shape ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" @@ -123,12 +119,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - fn = conv_dict[self.dim] + fn = conv_dict[self.ndim] t_sum = fn(target, weight=self.kernel, padding=self.padding) - p_sum = fn(input, weight=self.kernel, padding=self.paddin) - t2_sum = fn(t2, weight=self.kernel, padding=self.paddin) - p2_sum = fn(p2, weight=self.kernel, padding=self.paddin) - tp_sum = fn(tp, weight=self.kernel, padding=self.paddin) + p_sum = fn(input, weight=self.kernel, padding=self.padding) + t2_sum = fn(t2, weight=self.kernel, padding=self.padding) + p2_sum = fn(p2, weight=self.kernel, padding=self.padding) + tp_sum = fn(tp, weight=self.kernel, padding=self.padding) # average over kernel t_avg = t_sum / self.kernel_vol @@ -149,9 +145,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return torch.sum(ncc) # sum over the batch and channel dims + return -torch.sum(ncc) # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: - return ncc + return -ncc if self.reduction == LossReduction.MEAN.value: - return torch.mean(ncc) # average over the batch and channel dims + return -torch.mean(ncc) # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py new file mode 100644 index 0000000000..332689be06 --- /dev/null +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -0,0 +1,121 @@ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss + +TEST_CASES = [ + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.06062524, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.9368649, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.50272596, + ], +] + + +class TestBendingEnergy(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-2) + + def test_ill_shape(self): + loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) + # in_channel unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 2, 3, 3, 3), dtype=torch.float), torch.ones((1, 2, 3, 3, 3), dtype=torch.float)) + # ndim unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 3, 3), dtype=torch.float)) + # input, target shape unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 4, 4, 4), dtype=torch.float)) + + def test_ill_opts(self): + input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(input, target) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(input, target) + + +if __name__ == "__main__": + unittest.main() From 9376195b34df1eab0fc0f94b251102184e23864f Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 01:37:55 +0000 Subject: [PATCH 03/28] 1412 fix bug Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 26ee7ea718..42f0cfd9d7 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -43,7 +43,7 @@ def __init__( - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. """ - super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) + super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels self.ndim = ndim if self.ndim not in [1, 2, 3]: From ac36a9fc877882cf8da48378479782a9e58c7b8e Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 15:00:23 +0000 Subject: [PATCH 04/28] 1412 reformat code Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 58 ++++++++++++------- ...local_normalized_cross_correlation_loss.py | 15 ++++- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 42f0cfd9d7..232b8e0600 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -1,3 +1,14 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss @@ -6,8 +17,6 @@ conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} -EPS = 1e-7 - class LocalNormalizedCrossCorrelationLoss(_Loss): """ @@ -29,6 +38,7 @@ def __init__( kernel_size: int = 9, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_dr: float = 1e-7, ) -> None: """ Args: @@ -42,12 +52,14 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_dr: a small constant added to the denominator to avoid nan. """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels self.ndim = ndim if self.ndim not in [1, 2, 3]: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") + self.fn = conv_dict[self.ndim] self.kernel_size = kernel_size if kernel_type == "rectangular": self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() @@ -59,26 +71,29 @@ def __init__( raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' ) + self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): shape = [1, self.in_channels] + [self.kernel_size] * self.ndim return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) def make_triangular_kernel(self): - fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) - f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize # (1, 1, D, H, W) - f2 = ( - torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize - ) # (1, in_channels, D, H, W) + fsize = int((self.kernel_size + 1) // 2) + f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) # (1, 1, D, H, W) + f1 = F.pad(f1, [(fsize - 1) // 2, (fsize - 1) // 2] * self.ndim) + f2 = torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) + # (in_channels, 1, D, H, W) # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - fn = conv_dict[self.ndim] - kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) + padding_needed = max(fsize - 1, 0) + padding = [padding_needed // 2, padding_needed - padding_needed // 2] * self.ndim + f1 = F.pad(f1, padding) + kernel = self.fn(f1, f2) - return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) + return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2.0) def make_gaussian_kernel(self): mean = (self.kernel_size - 1) / 2.0 - sigma = self.kernel_size / 3 + sigma = self.kernel_size / 3.0 grid_ndim = torch.arange(0, self.kernel_size) grid_ndim_ch = torch.arange(0, self.in_channels) @@ -96,7 +111,7 @@ def make_gaussian_kernel(self): kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( 0 ) # (1, in_channel, kernel_size, kernel_size, kernel_size) - return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) + return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2.0) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -119,12 +134,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - fn = conv_dict[self.ndim] - t_sum = fn(target, weight=self.kernel, padding=self.padding) - p_sum = fn(input, weight=self.kernel, padding=self.padding) - t2_sum = fn(t2, weight=self.kernel, padding=self.padding) - p2_sum = fn(p2, weight=self.kernel, padding=self.padding) - tp_sum = fn(tp, weight=self.kernel, padding=self.padding) + t_sum = self.fn(target, weight=self.kernel, padding=self.padding) + p_sum = self.fn(input, weight=self.kernel, padding=self.padding) + t2_sum = self.fn(t2, weight=self.kernel, padding=self.padding) + p2_sum = self.fn(p2, weight=self.kernel, padding=self.padding) + tp_sum = self.fn(tp, weight=self.kernel, padding=self.padding) # average over kernel t_avg = t_sum / self.kernel_vol @@ -142,12 +156,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) + ncc = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return -torch.sum(ncc) # sum over the batch and channel ndims + return -torch.sum(ncc).neg() # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: - return -ncc + return ncc.neg() if self.reduction == LossReduction.MEAN.value: - return -torch.mean(ncc) # average over the batch and channel ndims + return torch.mean(ncc).neg() # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 332689be06..beb1a1dca2 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -1,3 +1,14 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import numpy as np @@ -72,7 +83,7 @@ -0.06062524, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "triangular"}, + {"in_channels": 3, "ndim": 3, "kernel_size": 6, "kernel_type": "triangular"}, { "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, @@ -90,7 +101,7 @@ ] -class TestBendingEnergy(unittest.TestCase): +class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) From ed6c28bca7393a527b5bc21a0442f0a08333d33c Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 15:20:47 +0000 Subject: [PATCH 05/28] 1412 debug type check Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 232b8e0600..4e64d81eb6 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -156,7 +156,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) + ncc: torch.Tensor = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) + # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: return -torch.sum(ncc).neg() # sum over the batch and channel ndims From 43c2f355416b249a34adfa87367aabd5f1b381f0 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 00:42:45 +0000 Subject: [PATCH 06/28] 1412 use separable filter for speed Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 75 +++++++------------ monai/networks/layers/simplelayers.py | 6 ++ ...local_normalized_cross_correlation_loss.py | 10 ++- 3 files changed, 41 insertions(+), 50 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 4e64d81eb6..bd9a06e187 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -13,10 +13,9 @@ from torch.nn import functional as F from torch.nn.modules.loss import _Loss +from monai.networks.layers import gaussian_1d, separable_filtering from monai.utils import LossReduction, Union -conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} - class LocalNormalizedCrossCorrelationLoss(_Loss): """ @@ -44,7 +43,7 @@ def __init__( Args: in_channels: number of input channels ndim: number of spatial ndimensions, {``1``, ``2``, ``3``}. Defaults to 3. - kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` + kernel_size: kernel spatial size, must be odd. kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. @@ -56,62 +55,46 @@ def __init__( """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels + self.ndim = ndim if self.ndim not in [1, 2, 3]: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") - self.fn = conv_dict[self.ndim] + self.kernel_size = kernel_size + if self.kernel_size % 2 == 0: + raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") + if kernel_type == "rectangular": - self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() + self.kernel = self.make_rectangular_kernel() elif kernel_type == "triangular": - self.kernel, self.kernel_vol, self.padding = self.make_triangular_kernel() + self.kernel = self.make_triangular_kernel() elif kernel_type == "gaussian": - self.kernel, self.kernel_vol, self.padding = self.make_gaussian_kernel() + self.kernel = self.make_gaussian_kernel() else: raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' ) + + self.kernel_vol = torch.sum(self.kernel) ** self.ndim self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): - shape = [1, self.in_channels] + [self.kernel_size] * self.ndim - return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) + return torch.ones(self.kernel_size) def make_triangular_kernel(self): - fsize = int((self.kernel_size + 1) // 2) - f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) # (1, 1, D, H, W) - f1 = F.pad(f1, [(fsize - 1) // 2, (fsize - 1) // 2] * self.ndim) - f2 = torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) - # (in_channels, 1, D, H, W) - # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - padding_needed = max(fsize - 1, 0) - padding = [padding_needed // 2, padding_needed - padding_needed // 2] * self.ndim - f1 = F.pad(f1, padding) - kernel = self.fn(f1, f2) - - return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2.0) + fsize = (self.kernel_size + 1) // 2 + if fsize % 2 == 0: + fsize -= 1 + f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize) + padding = (self.kernel_size - fsize) // 2 + fsize // 2 + return F.conv1d(f, f, padding=padding).reshape(-1) def make_gaussian_kernel(self): - mean = (self.kernel_size - 1) / 2.0 - sigma = self.kernel_size / 3.0 - - grid_ndim = torch.arange(0, self.kernel_size) - grid_ndim_ch = torch.arange(0, self.in_channels) - - if self.ndim == 1: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim) - elif self.ndim == 2: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim) - elif self.ndim == 3: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim, grid_ndim) - else: - raise ValueError - - grid = torch.stack(grid, dim=-1).to(dtype=torch.float) - kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( - 0 - ) # (1, in_channel, kernel_size, kernel_size, kernel_size) - return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2.0) + sigma = torch.tensor(self.kernel_size / 3.0) + kernel = gaussian_1d(sigma=sigma, truncated=self.kernel_size // 2, approx="sampled", normalize=False) * ( + 2.5066282 * sigma + ) + return kernel[: self.kernel_size] def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -134,11 +117,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - t_sum = self.fn(target, weight=self.kernel, padding=self.padding) - p_sum = self.fn(input, weight=self.kernel, padding=self.padding) - t2_sum = self.fn(t2, weight=self.kernel, padding=self.padding) - p2_sum = self.fn(p2, weight=self.kernel, padding=self.padding) - tp_sum = self.fn(tp, weight=self.kernel, padding=self.padding) + t_sum = separable_filtering(target, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + p_sum = separable_filtering(input, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + t2_sum = separable_filtering(t2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + p2_sum = separable_filtering(p2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + tp_sum = separable_filtering(tp, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) # average over kernel t_avg = t_sum / self.kernel_vol diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 48012dfb1c..8ff9464145 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -365,3 +365,9 @@ def reset_parameters(self): def forward(self, input, state): return LLTMFunction.apply(input, self.weights, self.bias, *state) + + +if __name__ == "__main__": + input = torch.ones((1, 3, 3, 3)) + kernels = [torch.ones(1, 3)] * 2 + print(separable_filtering(input, kernels)) diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index beb1a1dca2..fe455b0597 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -83,12 +83,12 @@ -0.06062524, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 6, "kernel_type": "triangular"}, + {"in_channels": 3, "ndim": 3, "kernel_size": 5, "kernel_type": "triangular"}, { "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, - -0.9368649, + -0.923356, ], [ {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, @@ -96,7 +96,7 @@ "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, - -0.50272596, + -1.306177, ], ] @@ -105,7 +105,7 @@ class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-2) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) def test_ill_shape(self): loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) @@ -122,6 +122,8 @@ def test_ill_shape(self): def test_ill_opts(self): input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(input, target) with self.assertRaisesRegex(ValueError, ""): LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(input, target) with self.assertRaisesRegex(ValueError, ""): From f76e3f04bcc4f72eb3fb989261e671bcb5252a43 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 00:49:19 +0000 Subject: [PATCH 07/28] 1412 update Union import route Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index bd9a06e187..093aa2546b 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,13 +8,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss from monai.networks.layers import gaussian_1d, separable_filtering -from monai.utils import LossReduction, Union +from monai.utils import LossReduction class LocalNormalizedCrossCorrelationLoss(_Loss): From 50940c2982da3a97f33abb09b21b140641a47456 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 03:02:08 +0000 Subject: [PATCH 08/28] 1412 fix negative bug and add smooth_nr Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 093aa2546b..9a9a0291dd 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -38,6 +38,7 @@ def __init__( kernel_size: int = 9, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, smooth_dr: float = 1e-7, ) -> None: """ @@ -52,6 +53,7 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid zero. smooth_dr: a small constant added to the denominator to avoid nan. """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) @@ -77,6 +79,7 @@ def __init__( ) self.kernel_vol = torch.sum(self.kernel) ** self.ndim + self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): @@ -140,11 +143,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc: torch.Tensor = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) + ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return -torch.sum(ncc).neg() # sum over the batch and channel ndims + return torch.sum(ncc).neg() # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: return ncc.neg() if self.reduction == LossReduction.MEAN.value: From cab9b0beb4f514c68f62e6b950d428963d209ab7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 10 Jan 2021 15:44:48 +0000 Subject: [PATCH 09/28] remove temp. code Signed-off-by: Wenqi Li --- docs/source/losses.rst | 2 +- monai/networks/layers/simplelayers.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index fc86c9cd05..462c303e65 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -67,6 +67,6 @@ Registration Losses :members: `LocalNormalizedCrossCorrelationLoss` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss :members: diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 55189e3e9b..ba60f4eca4 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -365,9 +365,3 @@ def reset_parameters(self): def forward(self, input, state): return LLTMFunction.apply(input, self.weights, self.bias, *state) - - -if __name__ == "__main__": - input = torch.ones((1, 3, 3, 3)) - kernels = [torch.ones(1, 3)] * 2 - print(separable_filtering(input, kernels)) From 6db25289fd07ae4cfc3b83a042f127acd94e0042 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sun, 10 Jan 2021 17:07:45 +0000 Subject: [PATCH 10/28] 1412 reformat code Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 64 ++++++++++--------- ...local_normalized_cross_correlation_loss.py | 12 ++++ 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 9a9a0291dd..65b53309f5 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Tuple, Union import torch from torch.nn import functional as F @@ -18,6 +18,34 @@ from monai.utils import LossReduction +def make_rectangular_kernel(kernel_size: int) -> torch.Tensor: + return torch.ones(kernel_size) + + +def make_triangular_kernel(kernel_size: int) -> torch.Tensor: + fsize = (kernel_size + 1) // 2 + if fsize % 2 == 0: + fsize -= 1 + f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize) + padding = (kernel_size - fsize) // 2 + fsize // 2 + return F.conv1d(f, f, padding=padding).reshape(-1) + + +def make_gaussian_kernel(kernel_size: int) -> torch.Tensor: + sigma = torch.tensor(kernel_size / 3.0) + kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx="sampled", normalize=False) * ( + 2.5066282 * sigma + ) + return kernel[:kernel_size] + + +kernel_dict = { + "rectangular": make_rectangular_kernel, + "triangular": make_triangular_kernel, + "gaussian": make_gaussian_kernel, +} + + class LocalNormalizedCrossCorrelationLoss(_Loss): """ Local squared zero-normalized cross-correlation. @@ -53,7 +81,7 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. - smooth_nr: a small constant added to the numerator to avoid zero. + smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) @@ -67,39 +95,15 @@ def __init__( if self.kernel_size % 2 == 0: raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") - if kernel_type == "rectangular": - self.kernel = self.make_rectangular_kernel() - elif kernel_type == "triangular": - self.kernel = self.make_triangular_kernel() - elif kernel_type == "gaussian": - self.kernel = self.make_gaussian_kernel() - else: + if kernel_type not in kernel_dict.keys(): raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' ) - + self.kernel = kernel_dict[kernel_type](self.kernel_size) self.kernel_vol = torch.sum(self.kernel) ** self.ndim self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def make_rectangular_kernel(self): - return torch.ones(self.kernel_size) - - def make_triangular_kernel(self): - fsize = (self.kernel_size + 1) // 2 - if fsize % 2 == 0: - fsize -= 1 - f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize) - padding = (self.kernel_size - fsize) // 2 + fsize // 2 - return F.conv1d(f, f, padding=padding).reshape(-1) - - def make_gaussian_kernel(self): - sigma = torch.tensor(self.kernel_size / 3.0) - kernel = gaussian_1d(sigma=sigma, truncated=self.kernel_size // 2, approx="sampled", normalize=False) * ( - 2.5066282 * sigma - ) - return kernel[: self.kernel_size] - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: @@ -147,9 +151,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return torch.sum(ncc).neg() # sum over the batch and channel ndims + return torch.sum(ncc).neg() # sum over the batch and spatial ndims if self.reduction == LossReduction.NONE.value: return ncc.neg() if self.reduction == LossReduction.MEAN.value: - return torch.mean(ncc).neg() # average over the batch and channel ndims + return torch.mean(ncc).neg() # average over the batch and spatial ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index fe455b0597..cb2f446dfc 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -74,6 +74,14 @@ }, -1.0, ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian", "reduction": "sum"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), + }, + -6.0, + ], [ {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, { @@ -122,6 +130,10 @@ def test_ill_shape(self): def test_ill_opts(self): input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type="unknown")(input, target) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type=None)(input, target) with self.assertRaisesRegex(ValueError, ""): LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(input, target) with self.assertRaisesRegex(ValueError, ""): From af4cab559fc56ed47139f6546b443273fbafc605 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sun, 10 Jan 2021 17:35:07 +0000 Subject: [PATCH 11/28] 1412 remove redundant import Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 65b53309f5..d42303e154 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch.nn import functional as F From 8cc4f88e9ade6c7fa8361489bbe9a76ddad1a098 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 14:34:36 +0000 Subject: [PATCH 12/28] 1442 add localnet Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 245 ++++++++++++++++++++++++ monai/networks/nets/localnet.py | 126 ++++++++++++ tests/test_localnet.py | 82 ++++++++ 3 files changed, 453 insertions(+) create mode 100644 monai/networks/blocks/localnet_block.py create mode 100644 monai/networks/nets/localnet.py create mode 100644 tests/test_localnet.py diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py new file mode 100644 index 0000000000..01edc45871 --- /dev/null +++ b/monai/networks/blocks/localnet_block.py @@ -0,0 +1,245 @@ +from typing import Union, Sequence, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from monai.networks.blocks import Convolution, get_padding +from monai.networks.layers.factories import batch_factory, maxpooling_factory + + +initializer_dict = { + "zeros": nn.init.zeros_, +} + + +def get_conv_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, + act: Optional[Union[Tuple, str]] = "RELU", + norm: Optional[Union[Tuple, str]] = "BATCH", +): + padding = get_padding(kernel_size, stride=1) + return Convolution( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + act=act, + norm=norm, + bias=False, + conv_only=False, + padding=padding, + ) + + +def get_conv_layer( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, +): + padding = get_padding(kernel_size, stride=1) + return Convolution( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + bias=False, + conv_only=True, + padding=padding, + ) + + +def get_deconv_block( + spatial_dims: int, + in_channels: int, + out_channels: int, +): + return Convolution( + dimensions=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + strides=2, + act="RELU", + norm="BATCH", + bias=False, + is_transposed=True, + padding=1, + output_padding=1, + ) + + +class ResidualBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + ): + super(ResidualBlock, self).__init__() + if in_channels != out_channels: + raise ValueError( + f"expecting in_channels == out_channels, " + f"got in_channels={in_channels}, out_channels={out_channels}") + self.conv_block = get_conv_block( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + ) + self.conv = get_conv_layer( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size + ) + self.norm = batch_factory(spatial_dims)(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu( + self.norm( + self.conv( + self.conv_block(x) + ) + ) + x + ) + + +class LocalNetResidualBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + ): + super(LocalNetResidualBlock, self).__init__() + # if in_channels != out_channels: + # raise ValueError( + # f"expecting in_channels == out_channels, " + # f"got in_channels={in_channels}, out_channels={out_channels}") + self.conv_layer = get_conv_layer( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + ) + self.norm = batch_factory(spatial_dims)(out_channels) + self.relu = nn.ReLU() + + def forward(self, x, mid): + return self.relu( + self.norm( + self.conv_layer(x) + ) + mid + ) + + +class LocalNetDownSampleBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + ): + super(LocalNetDownSampleBlock, self).__init__() + self.conv_block = get_conv_block( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size + ) + self.residual_block = ResidualBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size + ) + self.max_pool = maxpooling_factory(spatial_dims)( + kernel_size=2, + ) + + def forward(self, x): + x = self.conv_block(x) + mid = self.residual_block(x) + x = self.max_pool(mid) + return x, mid + + +class LocalNetUpSampleBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + ): + super(LocalNetUpSampleBlock, self).__init__() + self.deconv_block = get_deconv_block( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + ) + self.conv_block = get_conv_block( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + ) + self.residual_block = LocalNetResidualBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + ) + if in_channels / out_channels != 2: + raise ValueError( + f"expecting in_channels == 2 * out_channels, " + f"got in_channels={in_channels}, out_channels={out_channels}" + ) + self.out_channels = out_channels + + def addictive_upsampling(self, x): + size = torch.Size(torch.tensor(x.shape[2:]) * 2) + x = F.interpolate(x, size) + # [(batch, out_channels, ...), (batch, out_channels, ...)] + x = x.split(split_size=int(self.out_channels), dim=1) + # (batch, out_channels, ...) + x = torch.sum( + torch.stack(x, dim=-1), + dim=-1 + ) + return x + + def forward(self, x, mid): + h0 = self.deconv_block(x) + self.addictive_upsampling(x) + r1 = h0 + mid + r2 = self.conv_block(h0) + return self.residual_block(r2, r1) + + +class ExtractBlock(nn.Module): + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + act: Optional[Union[Tuple, str]] = "RELU", + kernel_initializer: Optional[Union[Tuple, str]] = None, + ): + super(ExtractBlock, self).__init__() + self.conv_block = get_conv_block( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + act=act, + norm=None + ) + if kernel_initializer: + initializer_dict[kernel_initializer](self.conv_block.conv.weight) + + def forward(self, x): + x = self.conv_block(x) + return x \ No newline at end of file diff --git a/monai/networks/nets/localnet.py b/monai/networks/nets/localnet.py new file mode 100644 index 0000000000..0320151224 --- /dev/null +++ b/monai/networks/nets/localnet.py @@ -0,0 +1,126 @@ +from typing import List, Optional, Union, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from monai.networks.blocks.localnet_block import LocalNetDownSampleBlock, get_conv_block, LocalNetUpSampleBlock, \ + ExtractBlock + + +class LocalNet(nn.Module): + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_channel_initial: int, + extract_levels: List[int], + out_kernel_initializer: str, + out_activation: Optional[Union[Tuple, str]], + control_points: (tuple, None) = None, + **kwargs, + ): + super(LocalNet, self).__init__() + self.extract_levels = extract_levels + self.extract_max_level = max(self.extract_levels) # E + self.extract_min_level = min(self.extract_levels) # D + + num_channels = [ + num_channel_initial * (2 ** level) + for level in range(self.extract_max_level + 1) + ] # level 0 to E + + self.downsample_blocks = nn.ModuleList([ + LocalNetDownSampleBlock( + spatial_dims=spatial_dims, + in_channels=in_channels if i == 0 else num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=7 if i == 0 else 3 + ) + for i in range(self.extract_max_level) + ]) # level 0 to E-1 + self.conv3d_block = get_conv_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-2], + out_channels=num_channels[-1] + ) # level E + + self.upsample_blocks = nn.ModuleList([ + LocalNetUpSampleBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[level + 1], + out_channels=num_channels[level], + ) + for level in range( + self.extract_max_level - 1, self.extract_min_level - 1, -1 + ) + ]) # level D to E-1 + + self.extract_layers = nn.ModuleList([ + # if kernels are not initialized by zeros, with init NN, extract may be too large + ExtractBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[level], + out_channels=out_channels, + kernel_initializer=out_kernel_initializer, + act=out_activation + ) + for level in self.extract_levels + ]) + + def forward(self, x): + image_size = x.shape[2:] + for size in image_size: + if size % (2 ** self.extract_max_level) != 0: + raise ValueError( + f"given extract_max_level {self.extract_max_level}, " + f"all input spatial dimension must be devidable by {2 ** self.extract_max_level}, " + f"got input of size {image_size}") + image_size = tuple(image_size) + encoded = [] + h_in = x + for level in range(self.extract_max_level): + h_in, h_channel = self.downsample_blocks[level](h_in) + encoded.append(h_channel) + h_bottom = self.conv3d_block(h_in) + + decoded = [h_bottom] + # level E-1 to D + for idx, level in enumerate(range(self.extract_max_level - 1, self.extract_min_level - 1, -1)): + h_bottom = self.upsample_blocks[idx](h_bottom, encoded[level]) + decoded.append(h_bottom) + + output = torch.mean( + torch.stack( + [ + F.interpolate( + self.extract_layers[idx](decoded[self.extract_max_level - level]), + size=image_size + ) + for idx, level in enumerate(self.extract_levels) + ], + dim=-1 + ), + dim=-1 + ) + return output + + +if __name__ == '__main__': + input = torch.rand((2, 2, 32, 32, 32)) + model = LocalNet( + spatial_dims=3, + in_channels=2, + out_channels=3, + num_channel_initial=32, + extract_levels=[0, 1, 2, 3], + out_kernel_initializer="zeros", + out_activation=None, + ) + out = model(input) + print(out.shape) + + + diff --git a/tests/test_localnet.py b/tests/test_localnet.py new file mode 100644 index 0000000000..47d26beadb --- /dev/null +++ b/tests/test_localnet.py @@ -0,0 +1,82 @@ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.localnet import LocalNet +from tests.utils import test_script_save + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASE_LOCALNET_2D = [] +for out_kernel_initializer in ["zeros", None]: + for out_activation in [None]: + spatial_dims = 2 + in_channels = 2 + out_channels = 2 + num_channel_initial = 16 + extract_levels = [0, 1, 2] + in_size = 32 + input_shape = (1, in_channels, *[in_size] * spatial_dims) + expected_shape = (1, out_channels, *[in_size] * spatial_dims) + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "num_channel_initial": num_channel_initial, + "extract_levels": extract_levels, + "out_kernel_initializer": out_kernel_initializer, + "out_activation": out_activation + }, + input_shape, + expected_shape + ] + TEST_CASE_LOCALNET_2D.append(test_case) + + +TEST_CASE_LOCALNET_3D = [] +for in_channels in [2, 3]: + for out_channels in [1, 3]: + for num_channel_initial in [16, 32]: + for extract_levels in [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]: + for out_kernel_initializer in ["zeros", None]: + for out_activation in [None]: + spatial_dims = 3 + in_size = 32 + input_shape = (1, in_channels, *[in_size] * spatial_dims) + expected_shape = (1, out_channels, *[in_size] * spatial_dims) + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "num_channel_initial": num_channel_initial, + "extract_levels": extract_levels, + "out_kernel_initializer": out_kernel_initializer, + "out_activation": out_activation + }, + input_shape, + expected_shape + ] + TEST_CASE_LOCALNET_3D.append(test_case) + + +class TestDynUNet(unittest.TestCase): + @parameterized.expand(TEST_CASE_LOCALNET_2D) + def test_shape(self, input_param, input_shape, expected_shape): + net = LocalNet(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_script(self): + input_param, input_shape, _ = TEST_CASE_LOCALNET_2D[0] + net = LocalNet(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 3455ee82db8aac996fcb797f1d27719adabe97b8 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 17:21:45 +0000 Subject: [PATCH 13/28] 1442 add test Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 159 +++++++++++------------- monai/networks/nets/localnet.py | 149 +++++++++++----------- tests/test_localnet.py | 91 ++++++-------- tests/test_localnet_block.py | 93 ++++++++++++++ 4 files changed, 276 insertions(+), 216 deletions(-) create mode 100644 tests/test_localnet_block.py diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 01edc45871..f0cc776163 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -1,27 +1,27 @@ -from typing import Union, Sequence, Optional, Tuple +from typing import Optional, Sequence, Tuple, Union import torch from torch import nn from torch.nn import functional as F -from monai.networks.blocks import Convolution, get_padding +from monai.networks.blocks import Convolution +from monai.networks.layers import same_padding from monai.networks.layers.factories import batch_factory, maxpooling_factory - initializer_dict = { "zeros": nn.init.zeros_, } def get_conv_block( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, - act: Optional[Union[Tuple, str]] = "RELU", - norm: Optional[Union[Tuple, str]] = "BATCH", + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, + act: Optional[Union[Tuple, str]] = "RELU", + norm: Optional[Union[Tuple, str]] = "BATCH", ): - padding = get_padding(kernel_size, stride=1) + padding = same_padding(kernel_size) return Convolution( spatial_dims, in_channels, @@ -36,12 +36,12 @@ def get_conv_block( def get_conv_layer( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, ): - padding = get_padding(kernel_size, stride=1) + padding = same_padding(kernel_size) return Convolution( spatial_dims, in_channels, @@ -54,9 +54,9 @@ def get_conv_layer( def get_deconv_block( - spatial_dims: int, - in_channels: int, - out_channels: int, + spatial_dims: int, + in_channels: int, + out_channels: int, ): return Convolution( dimensions=spatial_dims, @@ -74,17 +74,17 @@ def get_deconv_block( class ResidualBlock(nn.Module): def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], ): super(ResidualBlock, self).__init__() if in_channels != out_channels: raise ValueError( - f"expecting in_channels == out_channels, " - f"got in_channels={in_channels}, out_channels={out_channels}") + f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}" + ) self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, @@ -92,78 +92,62 @@ def __init__( kernel_size=kernel_size, ) self.conv = get_conv_layer( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - kernel_size=kernel_size + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size ) self.norm = batch_factory(spatial_dims)(out_channels) self.relu = nn.ReLU() def forward(self, x): - return self.relu( - self.norm( - self.conv( - self.conv_block(x) - ) - ) + x - ) + return self.relu(self.norm(self.conv(self.conv_block(x))) + x) class LocalNetResidualBlock(nn.Module): def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, + self, + spatial_dims: int, + in_channels: int, + out_channels: int, ): super(LocalNetResidualBlock, self).__init__() - # if in_channels != out_channels: - # raise ValueError( - # f"expecting in_channels == out_channels, " - # f"got in_channels={in_channels}, out_channels={out_channels}") + if in_channels != out_channels: + raise ValueError( + f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}" + ) self.conv_layer = get_conv_layer( spatial_dims=spatial_dims, - in_channels=out_channels, + in_channels=in_channels, out_channels=out_channels, ) self.norm = batch_factory(spatial_dims)(out_channels) self.relu = nn.ReLU() def forward(self, x, mid): - return self.relu( - self.norm( - self.conv_layer(x) - ) + mid - ) + return self.relu(self.norm(self.conv_layer(x)) + mid) class LocalNetDownSampleBlock(nn.Module): def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], ): super(LocalNetDownSampleBlock, self).__init__() self.conv_block = get_conv_block( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size ) self.residual_block = ResidualBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - kernel_size=kernel_size + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size ) self.max_pool = maxpooling_factory(spatial_dims)( kernel_size=2, ) def forward(self, x): + for i in x.shape[2:]: + if i % 2 != 0: + raise ValueError("expecting x spatial dimensions be even, " f"got x of shape {x.shape}") x = self.conv_block(x) mid = self.residual_block(x) x = self.max_pool(mid) @@ -172,10 +156,10 @@ def forward(self, x): class LocalNetUpSampleBlock(nn.Module): def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, + self, + spatial_dims: int, + in_channels: int, + out_channels: int, ): super(LocalNetUpSampleBlock, self).__init__() self.deconv_block = get_deconv_block( @@ -190,7 +174,7 @@ def __init__( ) self.residual_block = LocalNetResidualBlock( spatial_dims=spatial_dims, - in_channels=in_channels, + in_channels=out_channels, out_channels=out_channels, ) if in_channels / out_channels != 2: @@ -200,46 +184,43 @@ def __init__( ) self.out_channels = out_channels - def addictive_upsampling(self, x): - size = torch.Size(torch.tensor(x.shape[2:]) * 2) - x = F.interpolate(x, size) + def addictive_upsampling(self, x, mid): + x = F.interpolate(x, mid.shape[2:]) # [(batch, out_channels, ...), (batch, out_channels, ...)] x = x.split(split_size=int(self.out_channels), dim=1) # (batch, out_channels, ...) - x = torch.sum( - torch.stack(x, dim=-1), - dim=-1 - ) + x = torch.sum(torch.stack(x, dim=-1), dim=-1) return x def forward(self, x, mid): - h0 = self.deconv_block(x) + self.addictive_upsampling(x) + for i, j in zip(x.shape[2:], mid.shape[2:]): + if j != 2 * i: + raise ValueError( + "expecting mid spatial dimensions be exactly the double of x spatial dimensions, " + f"got x of shape {x.shape}, mid of shape {mid.shape}" + ) + h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid) r1 = h0 + mid r2 = self.conv_block(h0) return self.residual_block(r2, r1) class ExtractBlock(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - act: Optional[Union[Tuple, str]] = "RELU", - kernel_initializer: Optional[Union[Tuple, str]] = None, + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + act: Optional[Union[Tuple, str]] = "RELU", + kernel_initializer: Optional[Union[Tuple, str]] = None, ): super(ExtractBlock, self).__init__() self.conv_block = get_conv_block( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - act=act, - norm=None + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) if kernel_initializer: initializer_dict[kernel_initializer](self.conv_block.conv.weight) def forward(self, x): x = self.conv_block(x) - return x \ No newline at end of file + return x diff --git a/monai/networks/nets/localnet.py b/monai/networks/nets/localnet.py index 0320151224..234e56cc91 100644 --- a/monai/networks/nets/localnet.py +++ b/monai/networks/nets/localnet.py @@ -1,26 +1,29 @@ -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F -from monai.networks.blocks.localnet_block import LocalNetDownSampleBlock, get_conv_block, LocalNetUpSampleBlock, \ - ExtractBlock +from monai.networks.blocks.localnet_block import ( + ExtractBlock, + LocalNetDownSampleBlock, + LocalNetUpSampleBlock, + get_conv_block, +) class LocalNet(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - num_channel_initial: int, - extract_levels: List[int], - out_kernel_initializer: str, - out_activation: Optional[Union[Tuple, str]], - control_points: (tuple, None) = None, - **kwargs, + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_channel_initial: int, + extract_levels: List[int], + out_kernel_initializer: str, + out_activation: Optional[Union[Tuple, str]], + control_points: (tuple, None) = None, + **kwargs, ): super(LocalNet, self).__init__() self.extract_levels = extract_levels @@ -28,47 +31,48 @@ def __init__( self.extract_min_level = min(self.extract_levels) # D num_channels = [ - num_channel_initial * (2 ** level) - for level in range(self.extract_max_level + 1) + num_channel_initial * (2 ** level) for level in range(self.extract_max_level + 1) ] # level 0 to E - self.downsample_blocks = nn.ModuleList([ - LocalNetDownSampleBlock( - spatial_dims=spatial_dims, - in_channels=in_channels if i == 0 else num_channels[i - 1], - out_channels=num_channels[i], - kernel_size=7 if i == 0 else 3 - ) - for i in range(self.extract_max_level) - ]) # level 0 to E-1 + self.downsample_blocks = nn.ModuleList( + [ + LocalNetDownSampleBlock( + spatial_dims=spatial_dims, + in_channels=in_channels if i == 0 else num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=7 if i == 0 else 3, + ) + for i in range(self.extract_max_level) + ] + ) # level 0 to self.extract_max_level - 1 self.conv3d_block = get_conv_block( - spatial_dims=spatial_dims, - in_channels=num_channels[-2], - out_channels=num_channels[-1] - ) # level E - - self.upsample_blocks = nn.ModuleList([ - LocalNetUpSampleBlock( - spatial_dims=spatial_dims, - in_channels=num_channels[level + 1], - out_channels=num_channels[level], - ) - for level in range( - self.extract_max_level - 1, self.extract_min_level - 1, -1 - ) - ]) # level D to E-1 - - self.extract_layers = nn.ModuleList([ - # if kernels are not initialized by zeros, with init NN, extract may be too large - ExtractBlock( - spatial_dims=spatial_dims, - in_channels=num_channels[level], - out_channels=out_channels, - kernel_initializer=out_kernel_initializer, - act=out_activation - ) - for level in self.extract_levels - ]) + spatial_dims=spatial_dims, in_channels=num_channels[-2], out_channels=num_channels[-1] + ) # self.extract_max_level + + self.upsample_blocks = nn.ModuleList( + [ + LocalNetUpSampleBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[level + 1], + out_channels=num_channels[level], + ) + for level in range(self.extract_max_level - 1, self.extract_min_level - 1, -1) + ] + ) # self.extract_max_level - 1 to self.extract_min_level + + self.extract_layers = nn.ModuleList( + [ + # if kernels are not initialized by zeros, with init NN, extract may be too large + ExtractBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[level], + out_channels=out_channels, + kernel_initializer=out_kernel_initializer, + act=out_activation, + ) + for level in self.extract_levels + ] + ) def forward(self, x): image_size = x.shape[2:] @@ -77,38 +81,36 @@ def forward(self, x): raise ValueError( f"given extract_max_level {self.extract_max_level}, " f"all input spatial dimension must be devidable by {2 ** self.extract_max_level}, " - f"got input of size {image_size}") - image_size = tuple(image_size) - encoded = [] - h_in = x - for level in range(self.extract_max_level): - h_in, h_channel = self.downsample_blocks[level](h_in) - encoded.append(h_channel) - h_bottom = self.conv3d_block(h_in) - - decoded = [h_bottom] - # level E-1 to D - for idx, level in enumerate(range(self.extract_max_level - 1, self.extract_min_level - 1, -1)): - h_bottom = self.upsample_blocks[idx](h_bottom, encoded[level]) - decoded.append(h_bottom) + f"got input of size {image_size}" + ) + mid_features = [] # 0 -> self.extract_max_level - 1 + for downsample_block in self.downsample_blocks: + x, mid = downsample_block(x) + mid_features.append(mid) + x = self.conv3d_block(x) # self.extract_max_level + + decoded_features = [x] + for idx, upsample_block in enumerate(self.upsample_blocks): + x = upsample_block(x, mid_features[-idx - 1]) + decoded_features.append(x) # self.extract_max_level -> self.extract_min_level output = torch.mean( torch.stack( [ F.interpolate( - self.extract_layers[idx](decoded[self.extract_max_level - level]), - size=image_size + extract_layer(decoded_features[self.extract_max_level - self.extract_levels[idx]]), + size=image_size, ) - for idx, level in enumerate(self.extract_levels) + for idx, extract_layer in enumerate(self.extract_layers) ], - dim=-1 + dim=-1, ), - dim=-1 + dim=-1, ) return output -if __name__ == '__main__': +if __name__ == "__main__": input = torch.rand((2, 2, 32, 32, 32)) model = LocalNet( spatial_dims=3, @@ -121,6 +123,3 @@ def forward(self, x): ) out = model(input) print(out.shape) - - - diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 47d26beadb..c73abfc194 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -1,4 +1,5 @@ import unittest +from itertools import product import torch from parameterized import parameterized @@ -9,68 +10,54 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_LOCALNET_2D = [] -for out_kernel_initializer in ["zeros", None]: - for out_activation in [None]: - spatial_dims = 2 - in_channels = 2 - out_channels = 2 - num_channel_initial = 16 - extract_levels = [0, 1, 2] - in_size = 32 - input_shape = (1, in_channels, *[in_size] * spatial_dims) - expected_shape = (1, out_channels, *[in_size] * spatial_dims) - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "num_channel_initial": num_channel_initial, - "extract_levels": extract_levels, - "out_kernel_initializer": out_kernel_initializer, - "out_activation": out_activation - }, - input_shape, - expected_shape - ] - TEST_CASE_LOCALNET_2D.append(test_case) +param_variations_2d = { + "spatial_dims": [2], + "in_channels": [2], + "out_channels": [2], + "num_channel_initial": [16], + "extract_levels": [[0, 1, 2]], + "out_kernel_initializer": ["zeros", None], + "out_activation": ["sigmoid", None], +} +TEST_CASE_LOCALNET_2D = [dict(zip(param_variations_2d, v)) for v in product(*param_variations_2d.values())] +TEST_CASE_LOCALNET_2D = [ + [input_param, (1, input_param["in_channels"], 16, 16), (1, input_param["out_channels"], 16, 16)] + for input_param in TEST_CASE_LOCALNET_2D +] -TEST_CASE_LOCALNET_3D = [] -for in_channels in [2, 3]: - for out_channels in [1, 3]: - for num_channel_initial in [16, 32]: - for extract_levels in [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]: - for out_kernel_initializer in ["zeros", None]: - for out_activation in [None]: - spatial_dims = 3 - in_size = 32 - input_shape = (1, in_channels, *[in_size] * spatial_dims) - expected_shape = (1, out_channels, *[in_size] * spatial_dims) - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "num_channel_initial": num_channel_initial, - "extract_levels": extract_levels, - "out_kernel_initializer": out_kernel_initializer, - "out_activation": out_activation - }, - input_shape, - expected_shape - ] - TEST_CASE_LOCALNET_3D.append(test_case) + +param_variations_3d = { + "spatial_dims": [3], + "in_channels": [2, 3], + "out_channels": [1, 3], + "num_channel_initial": [16, 32], + "extract_levels": [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]], + "out_kernel_initializer": ["zeros", None], + "out_activation": ["sigmoid", None], +} +TEST_CASE_LOCALNET_3D = [dict(zip(param_variations_3d, v)) for v in product(*param_variations_3d.values())] +TEST_CASE_LOCALNET_3D = [ + [input_param, (1, input_param["in_channels"], 16, 16, 16), (1, input_param["out_channels"], 16, 16, 16)] + for input_param in TEST_CASE_LOCALNET_3D +] class TestDynUNet(unittest.TestCase): - @parameterized.expand(TEST_CASE_LOCALNET_2D) + @parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D) def test_shape(self, input_param, input_shape, expected_shape): net = LocalNet(**input_param).to(device) with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + def test_ill_shape(self): + with self.assertRaisesRegex(ValueError, ""): + input_param, _, _ = TEST_CASE_LOCALNET_2D[0] + input_shape = (1, input_param["in_channels"], 17, 17) + net = LocalNet(**input_param).to(device) + net.forward(torch.randn(input_shape).to(device)) + def test_script(self): input_param, input_shape, _ = TEST_CASE_LOCALNET_2D[0] net = LocalNet(**input_param) @@ -79,4 +66,4 @@ def test_script(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py new file mode 100644 index 0000000000..c66ebb2275 --- /dev/null +++ b/tests/test_localnet_block.py @@ -0,0 +1,93 @@ +import unittest +from itertools import product + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.localnet_block import ExtractBlock, LocalNetDownSampleBlock, LocalNetUpSampleBlock + +TEST_CASE_DOWN_SAMPLE = [ + [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3] +] + +TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]] + +extract_param_option = { + "spatial_dims": [2, 3], + "in_channels": [2], + "out_channels": [3], + "act": ["sigmoid", None], + "kernel_initializer": ["zeros", None], +} +TEST_CASE_EXTRACT = [dict(zip(extract_param_option, v)) for v in product(*extract_param_option.values())] +TEST_CASE_EXTRACT = [[i] for i in TEST_CASE_EXTRACT] + +in_size = 4 + + +class TestLocalNetDownSampleBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_DOWN_SAMPLE) + def test_shape(self, input_param): + net = LocalNetDownSampleBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) + expect_mid_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"])) + expect_x_shape = (1, input_param["out_channels"], *([in_size / 2] * input_param["spatial_dims"])) + with eval_mode(net): + x, mid = net(torch.randn(input_shape)) + self.assertEqual(x.shape, expect_x_shape) + self.assertEqual(mid.shape, expect_mid_shape) + + def test_ill_arg(self): + # even kernel_size + with self.assertRaises(NotImplementedError): + LocalNetDownSampleBlock(spatial_dims=2, in_channels=2, out_channels=4, kernel_size=4) + + @parameterized.expand(TEST_CASE_DOWN_SAMPLE) + def test_ill_shape(self, input_param): + net = LocalNetDownSampleBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([5] * input_param["spatial_dims"])) + with self.assertRaises(ValueError): + with eval_mode(net): + net(torch.randn(input_shape)) + + +class TestLocalNetUpSampleBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_UP_SAMPLE) + def test_shape(self, input_param): + net = LocalNetUpSampleBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) + mid_shape = (1, input_param["out_channels"], *([in_size * 2] * input_param["spatial_dims"])) + expected_shape = mid_shape + with eval_mode(net): + result = net(torch.randn(input_shape), torch.randn(mid_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + # channel unmatch + with self.assertRaises(ValueError): + LocalNetUpSampleBlock(spatial_dims=2, in_channels=2, out_channels=2) + + @parameterized.expand(TEST_CASE_UP_SAMPLE) + def test_ill_shape(self, input_param): + net = LocalNetUpSampleBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) + mid_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"])) + with self.assertRaises(ValueError): + with eval_mode(net): + net(torch.randn(input_shape), torch.randn(mid_shape)) + + +class TestExtractBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_EXTRACT) + def test_shape(self, input_param): + net = ExtractBlock(**input_param) + input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) + expected_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"])) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From 16011d7c564c73ea1d570a50b2219199c88183df Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 18:32:35 +0000 Subject: [PATCH 14/28] 1442 add documentation Signed-off-by: kate-sann5100 --- docs/source/networks.rst | 19 ++++++ monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/localnet_block.py | 85 ++++++++++++++++++++++++- monai/networks/nets/__init__.py | 1 + monai/networks/nets/localnet.py | 23 ++++++- tests/test_localnet_block.py | 8 ++- 6 files changed, 130 insertions(+), 7 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index ed17d815b4..420da311d2 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -119,6 +119,20 @@ Blocks .. autoclass:: Subpixelupsample .. autoclass:: SubpixelUpSample +`LocalNet DownSample Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetDownSampleBlock + :members: + +`LocalNet UpSample Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetUpSampleBlock + :members: + +`LocalNet Feature Extractor Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNetFeatureExtractorBlock + :members: Layers @@ -298,6 +312,11 @@ Nets .. autoclass:: VNet :members: +`LocalNet` +~~~~~~~~~~~ +.. autoclass:: LocalNet + :members: + `AutoEncoder` ~~~~~~~~~~~~~ .. autoclass:: AutoEncoder diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 10b13f619c..c33feb4e2b 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -16,6 +16,7 @@ from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine +from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock from .segresnet_block import ResBlock from .squeeze_and_excitation import ( ChannelSELayer, diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index f0cc776163..5d09b5a8dc 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -126,6 +126,15 @@ def forward(self, x, mid): class LocalNetDownSampleBlock(nn.Module): + """ + A down-sample module that can be used for LocalNet, based on: + `Weakly-supervised convolutional neural networks for multimodal image registration `_. + `Label-driven weakly-supervised learning for multimodal deformable image registration `_. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + def __init__( self, spatial_dims: int, @@ -133,6 +142,15 @@ def __init__( out_channels: int, kernel_size: Union[Sequence[int], int], ): + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + Raises: + NotImplementedError: when ``kernel_size`` is even + """ super(LocalNetDownSampleBlock, self).__init__() self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size @@ -144,7 +162,20 @@ def __init__( kernel_size=2, ) - def forward(self, x): + def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Halves the spatial dimensions. + A tuple of (x, mid) is returned: + + - x is the downsample result, in shape (batch, ``out_channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]), + - mid is the mid-level feature, in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]) + + Args: + x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3]) + + Raises: + ValueError: when input spatial dimensions are not even. + """ for i in x.shape[2:]: if i % 2 != 0: raise ValueError("expecting x spatial dimensions be even, " f"got x of shape {x.shape}") @@ -155,12 +186,29 @@ def forward(self, x): class LocalNetUpSampleBlock(nn.Module): + """ + A up-sample module that can be used for LocalNet, based on: + `Weakly-supervised convolutional neural networks for multimodal image registration `_. + `Label-driven weakly-supervised learning for multimodal deformable image registration `_. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, ): + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + Raises: + ValueError: when ``in_channels != 2 * out_channels`` + """ super(LocalNetUpSampleBlock, self).__init__() self.deconv_block = get_deconv_block( spatial_dims=spatial_dims, @@ -193,6 +241,16 @@ def addictive_upsampling(self, x, mid): return x def forward(self, x, mid): + """ + Halves the channel and doubles the spatial dimensions. + + Args: + x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3]) + mid: mid-level feature saved during down-sampling, in shape (batch, ``out_channels``, midsize_1, midsize_2, [midnsize_3]) + + Raises: + ValueError: when ``midsize != insize * 2`` + """ for i, j in zip(x.shape[2:], mid.shape[2:]): if j != 2 * i: raise ValueError( @@ -205,7 +263,16 @@ def forward(self, x, mid): return self.residual_block(r2, r1) -class ExtractBlock(nn.Module): +class LocalNetFeatureExtractorBlock(nn.Module): + """ + A feature-extraction module that can be used for LocalNet, based on: + `Weakly-supervised convolutional neural networks for multimodal image registration `_. + `Label-driven weakly-supervised learning for multimodal deformable image registration `_. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + def __init__( self, spatial_dims: int, @@ -214,7 +281,15 @@ def __init__( act: Optional[Union[Tuple, str]] = "RELU", kernel_initializer: Optional[Union[Tuple, str]] = None, ): - super(ExtractBlock, self).__init__() + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + act: activation type and arguments. Defaults to ReLU. + kernel_initializer: kernel initializer. Defaults to None. + """ + super(LocalNetFeatureExtractorBlock, self).__init__() self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) @@ -222,5 +297,9 @@ def __init__( initializer_dict[kernel_initializer](self.conv_block.conv.weight) def forward(self, x): + """ + Args: + x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3]) + """ x = self.conv_block(x) return x diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 6c7570ebf9..a9308de9d7 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -18,6 +18,7 @@ from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet from .generator import Generator from .highresnet import HighResBlock, HighResNet +from .localnet import LocalNet from .regressor import Regressor from .segresnet import SegResNet, SegResNetVAE from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154 diff --git a/monai/networks/nets/localnet.py b/monai/networks/nets/localnet.py index 234e56cc91..8e347eb2b8 100644 --- a/monai/networks/nets/localnet.py +++ b/monai/networks/nets/localnet.py @@ -5,14 +5,23 @@ from torch.nn import functional as F from monai.networks.blocks.localnet_block import ( - ExtractBlock, LocalNetDownSampleBlock, + LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock, get_conv_block, ) class LocalNet(nn.Module): + """ + Reimplementation of LocalNet, based on: + `Weakly-supervised convolutional neural networks for multimodal image registration `_. + `Label-driven weakly-supervised learning for multimodal deformable image registration `_. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + def __init__( self, spatial_dims: int, @@ -25,6 +34,16 @@ def __init__( control_points: (tuple, None) = None, **kwargs, ): + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_channel_initial: number of initial channels, + extract_levels: number of extraction levels, + out_kernel_initializer: initializer to use for kernels, + out_activation: activation to use at end layer, + """ super(LocalNet, self).__init__() self.extract_levels = extract_levels self.extract_max_level = max(self.extract_levels) # E @@ -63,7 +82,7 @@ def __init__( self.extract_layers = nn.ModuleList( [ # if kernels are not initialized by zeros, with init NN, extract may be too large - ExtractBlock( + LocalNetFeatureExtractorBlock( spatial_dims=spatial_dims, in_channels=num_channels[level], out_channels=out_channels, diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index c66ebb2275..b7fea3b694 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -5,7 +5,11 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks.localnet_block import ExtractBlock, LocalNetDownSampleBlock, LocalNetUpSampleBlock +from monai.networks.blocks.localnet_block import ( + LocalNetDownSampleBlock, + LocalNetFeatureExtractorBlock, + LocalNetUpSampleBlock, +) TEST_CASE_DOWN_SAMPLE = [ [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3] @@ -81,7 +85,7 @@ def test_ill_shape(self, input_param): class TestExtractBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_EXTRACT) def test_shape(self, input_param): - net = ExtractBlock(**input_param) + net = LocalNetFeatureExtractorBlock(**input_param) input_shape = (1, input_param["in_channels"], *([in_size] * input_param["spatial_dims"])) expected_shape = (1, input_param["out_channels"], *([in_size] * input_param["spatial_dims"])) with eval_mode(net): From a99782b1d6de26d7c5aa9655f6b58a2801ee00cb Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 18:37:06 +0000 Subject: [PATCH 15/28] 1442 add typing Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 26 ++++++++++++------------- monai/networks/nets/localnet.py | 21 ++------------------ 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 5d09b5a8dc..a219661fe3 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -20,7 +20,7 @@ def get_conv_block( kernel_size: Union[Sequence[int], int] = 3, act: Optional[Union[Tuple, str]] = "RELU", norm: Optional[Union[Tuple, str]] = "BATCH", -): +) -> nn.Module: padding = same_padding(kernel_size) return Convolution( spatial_dims, @@ -40,7 +40,7 @@ def get_conv_layer( in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] = 3, -): +) -> nn.Module: padding = same_padding(kernel_size) return Convolution( spatial_dims, @@ -57,7 +57,7 @@ def get_deconv_block( spatial_dims: int, in_channels: int, out_channels: int, -): +) -> nn.Module: return Convolution( dimensions=spatial_dims, in_channels=in_channels, @@ -79,7 +79,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int], - ): + ) -> None: super(ResidualBlock, self).__init__() if in_channels != out_channels: raise ValueError( @@ -97,7 +97,7 @@ def __init__( self.norm = batch_factory(spatial_dims)(out_channels) self.relu = nn.ReLU() - def forward(self, x): + def forward(self, x) -> torch.Tensor: return self.relu(self.norm(self.conv(self.conv_block(x))) + x) @@ -107,7 +107,7 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - ): + ) -> None: super(LocalNetResidualBlock, self).__init__() if in_channels != out_channels: raise ValueError( @@ -121,7 +121,7 @@ def __init__( self.norm = batch_factory(spatial_dims)(out_channels) self.relu = nn.ReLU() - def forward(self, x, mid): + def forward(self, x, mid) -> torch.Tensor: return self.relu(self.norm(self.conv_layer(x)) + mid) @@ -141,7 +141,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int], - ): + ) -> None: """ Args: spatial_dims: number of spatial dimensions. @@ -200,7 +200,7 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - ): + ) -> None: """ Args: spatial_dims: number of spatial dimensions. @@ -232,7 +232,7 @@ def __init__( ) self.out_channels = out_channels - def addictive_upsampling(self, x, mid): + def addictive_upsampling(self, x, mid) -> torch.Tensor: x = F.interpolate(x, mid.shape[2:]) # [(batch, out_channels, ...), (batch, out_channels, ...)] x = x.split(split_size=int(self.out_channels), dim=1) @@ -240,7 +240,7 @@ def addictive_upsampling(self, x, mid): x = torch.sum(torch.stack(x, dim=-1), dim=-1) return x - def forward(self, x, mid): + def forward(self, x, mid) -> torch.Tensor: """ Halves the channel and doubles the spatial dimensions. @@ -280,7 +280,7 @@ def __init__( out_channels: int, act: Optional[Union[Tuple, str]] = "RELU", kernel_initializer: Optional[Union[Tuple, str]] = None, - ): + ) -> None: """ Args: spatial_dims: number of spatial dimensions. @@ -296,7 +296,7 @@ def __init__( if kernel_initializer: initializer_dict[kernel_initializer](self.conv_block.conv.weight) - def forward(self, x): + def forward(self, x) -> torch.Tensor: """ Args: x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3]) diff --git a/monai/networks/nets/localnet.py b/monai/networks/nets/localnet.py index 8e347eb2b8..4f7826fb11 100644 --- a/monai/networks/nets/localnet.py +++ b/monai/networks/nets/localnet.py @@ -31,9 +31,7 @@ def __init__( extract_levels: List[int], out_kernel_initializer: str, out_activation: Optional[Union[Tuple, str]], - control_points: (tuple, None) = None, - **kwargs, - ): + ) -> None: """ Args: spatial_dims: number of spatial dimensions. @@ -93,7 +91,7 @@ def __init__( ] ) - def forward(self, x): + def forward(self, x) -> torch.Tensor: image_size = x.shape[2:] for size in image_size: if size % (2 ** self.extract_max_level) != 0: @@ -127,18 +125,3 @@ def forward(self, x): dim=-1, ) return output - - -if __name__ == "__main__": - input = torch.rand((2, 2, 32, 32, 32)) - model = LocalNet( - spatial_dims=3, - in_channels=2, - out_channels=3, - num_channel_initial=32, - extract_levels=[0, 1, 2, 3], - out_kernel_initializer="zeros", - out_activation=None, - ) - out = model(input) - print(out.shape) From 37724140e7ee4e40e127d00c2cf6a2a7073bb770 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 18:55:06 +0000 Subject: [PATCH 16/28] 1442 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index a219661fe3..2a86d69997 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -128,8 +128,10 @@ def forward(self, x, mid) -> torch.Tensor: class LocalNetDownSampleBlock(nn.Module): """ A down-sample module that can be used for LocalNet, based on: - `Weakly-supervised convolutional neural networks for multimodal image registration `_. - `Label-driven weakly-supervised learning for multimodal deformable image registration `_. + `Weakly-supervised convolutional neural networks for multimodal image registration + `_. + `Label-driven weakly-supervised learning for multimodal deformable image registration + `_. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) @@ -188,8 +190,10 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: class LocalNetUpSampleBlock(nn.Module): """ A up-sample module that can be used for LocalNet, based on: - `Weakly-supervised convolutional neural networks for multimodal image registration `_. - `Label-driven weakly-supervised learning for multimodal deformable image registration `_. + `Weakly-supervised convolutional neural networks for multimodal image registration + `_. + `Label-driven weakly-supervised learning for multimodal deformable image registration + `_. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) @@ -246,7 +250,8 @@ def forward(self, x, mid) -> torch.Tensor: Args: x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3]) - mid: mid-level feature saved during down-sampling, in shape (batch, ``out_channels``, midsize_1, midsize_2, [midnsize_3]) + mid: mid-level feature saved during down-sampling, + in shape (batch, ``out_channels``, midsize_1, midsize_2, [midnsize_3]) Raises: ValueError: when ``midsize != insize * 2`` @@ -266,8 +271,10 @@ def forward(self, x, mid) -> torch.Tensor: class LocalNetFeatureExtractorBlock(nn.Module): """ A feature-extraction module that can be used for LocalNet, based on: - `Weakly-supervised convolutional neural networks for multimodal image registration `_. - `Label-driven weakly-supervised learning for multimodal deformable image registration `_. + `Weakly-supervised convolutional neural networks for multimodal image registration + `_. + `Label-driven weakly-supervised learning for multimodal deformable image registration + `_. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) From c2373e99fbe89773157d2a637284f82fbcbce80f Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 19:38:35 +0000 Subject: [PATCH 17/28] 1442 reformat Signed-off-by: kate-sann5100 --- monai/networks/nets/localnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/localnet.py b/monai/networks/nets/localnet.py index 4f7826fb11..28573910eb 100644 --- a/monai/networks/nets/localnet.py +++ b/monai/networks/nets/localnet.py @@ -15,8 +15,10 @@ class LocalNet(nn.Module): """ Reimplementation of LocalNet, based on: - `Weakly-supervised convolutional neural networks for multimodal image registration `_. - `Label-driven weakly-supervised learning for multimodal deformable image registration `_. + `Weakly-supervised convolutional neural networks for multimodal image registration + `_. + `Label-driven weakly-supervised learning for multimodal deformable image registration + `_. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) From 99b3fac16f9e53c884bd5646267de7c92a61fcd7 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 21:16:22 +0000 Subject: [PATCH 18/28] 1442 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 20 ++++++++++---------- tests/test_localnet.py | 8 ++++---- tests/test_localnet_block.py | 4 ++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 2a86d69997..84f47bbf58 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -8,9 +8,7 @@ from monai.networks.layers import same_padding from monai.networks.layers.factories import batch_factory, maxpooling_factory -initializer_dict = { - "zeros": nn.init.zeros_, -} +initializer_dict = {"zeros": nn.init.zeros_, "kaiming_normal": nn.init.kaiming_normal_} def get_conv_block( @@ -98,7 +96,8 @@ def __init__( self.relu = nn.ReLU() def forward(self, x) -> torch.Tensor: - return self.relu(self.norm(self.conv(self.conv_block(x))) + x) + x: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x) + return x class LocalNetResidualBlock(nn.Module): @@ -122,7 +121,8 @@ def __init__( self.relu = nn.ReLU() def forward(self, x, mid) -> torch.Tensor: - return self.relu(self.norm(self.conv_layer(x)) + mid) + x = self.relu(self.norm(self.conv_layer(x)) + mid) + return x class LocalNetDownSampleBlock(nn.Module): @@ -241,7 +241,7 @@ def addictive_upsampling(self, x, mid) -> torch.Tensor: # [(batch, out_channels, ...), (batch, out_channels, ...)] x = x.split(split_size=int(self.out_channels), dim=1) # (batch, out_channels, ...) - x = torch.sum(torch.stack(x, dim=-1), dim=-1) + x: torch.Tensor = torch.sum(torch.stack(x, dim=-1), dim=-1) return x def forward(self, x, mid) -> torch.Tensor: @@ -265,7 +265,8 @@ def forward(self, x, mid) -> torch.Tensor: h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid) r1 = h0 + mid r2 = self.conv_block(h0) - return self.residual_block(r2, r1) + x = self.residual_block(r2, r1) + return x class LocalNetFeatureExtractorBlock(nn.Module): @@ -286,7 +287,7 @@ def __init__( in_channels: int, out_channels: int, act: Optional[Union[Tuple, str]] = "RELU", - kernel_initializer: Optional[Union[Tuple, str]] = None, + kernel_initializer: str = "kaiming_normal", ) -> None: """ Args: @@ -300,8 +301,7 @@ def __init__( self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) - if kernel_initializer: - initializer_dict[kernel_initializer](self.conv_block.conv.weight) + initializer_dict[kernel_initializer](self.conv_block.conv.weight) def forward(self, x) -> torch.Tensor: """ diff --git a/tests/test_localnet.py b/tests/test_localnet.py index c73abfc194..70ef1c0cc2 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -17,10 +17,10 @@ "out_channels": [2], "num_channel_initial": [16], "extract_levels": [[0, 1, 2]], - "out_kernel_initializer": ["zeros", None], + "out_kernel_initializer": ["zeros", "kaiming_normal"], "out_activation": ["sigmoid", None], } -TEST_CASE_LOCALNET_2D = [dict(zip(param_variations_2d, v)) for v in product(*param_variations_2d.values())] +TEST_CASE_LOCALNET_2D = [dict(zip(param_variations_2d.keys(), v)) for v in product(*param_variations_2d.values())] TEST_CASE_LOCALNET_2D = [ [input_param, (1, input_param["in_channels"], 16, 16), (1, input_param["out_channels"], 16, 16)] for input_param in TEST_CASE_LOCALNET_2D @@ -33,10 +33,10 @@ "out_channels": [1, 3], "num_channel_initial": [16, 32], "extract_levels": [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]], - "out_kernel_initializer": ["zeros", None], + "out_kernel_initializer": ["zeros", "kaiming_normal"], "out_activation": ["sigmoid", None], } -TEST_CASE_LOCALNET_3D = [dict(zip(param_variations_3d, v)) for v in product(*param_variations_3d.values())] +TEST_CASE_LOCALNET_3D = [dict(zip(param_variations_3d.keys(), v)) for v in product(*param_variations_3d.values())] TEST_CASE_LOCALNET_3D = [ [input_param, (1, input_param["in_channels"], 16, 16, 16), (1, input_param["out_channels"], 16, 16, 16)] for input_param in TEST_CASE_LOCALNET_3D diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index b7fea3b694..4ea97b0c20 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -22,9 +22,9 @@ "in_channels": [2], "out_channels": [3], "act": ["sigmoid", None], - "kernel_initializer": ["zeros", None], + "kernel_initializer": ["zeros", "kaiming_normal"], } -TEST_CASE_EXTRACT = [dict(zip(extract_param_option, v)) for v in product(*extract_param_option.values())] +TEST_CASE_EXTRACT = [dict(zip(extract_param_option.keys(), v)) for v in product(*extract_param_option.values())] TEST_CASE_EXTRACT = [[i] for i in TEST_CASE_EXTRACT] in_size = 4 From 9437fb30723f474fdf8cc667fd30731413a108a1 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 21:46:15 +0000 Subject: [PATCH 19/28] 1442 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 23 ++++++++++++----------- tests/test_localnet.py | 5 +++-- tests/test_localnet_block.py | 3 ++- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 84f47bbf58..e1cffb3358 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -96,8 +96,8 @@ def __init__( self.relu = nn.ReLU() def forward(self, x) -> torch.Tensor: - x: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x) - return x + out: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x) + return out class LocalNetResidualBlock(nn.Module): @@ -121,8 +121,8 @@ def __init__( self.relu = nn.ReLU() def forward(self, x, mid) -> torch.Tensor: - x = self.relu(self.norm(self.conv_layer(x)) + mid) - return x + out: torch.Tensor = self.relu(self.norm(self.conv_layer(x)) + mid) + return out class LocalNetDownSampleBlock(nn.Module): @@ -241,8 +241,8 @@ def addictive_upsampling(self, x, mid) -> torch.Tensor: # [(batch, out_channels, ...), (batch, out_channels, ...)] x = x.split(split_size=int(self.out_channels), dim=1) # (batch, out_channels, ...) - x: torch.Tensor = torch.sum(torch.stack(x, dim=-1), dim=-1) - return x + out: torch.Tensor = torch.sum(torch.stack(x, dim=-1), dim=-1) + return out def forward(self, x, mid) -> torch.Tensor: """ @@ -265,8 +265,8 @@ def forward(self, x, mid) -> torch.Tensor: h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid) r1 = h0 + mid r2 = self.conv_block(h0) - x = self.residual_block(r2, r1) - return x + out: torch.Tensor = self.residual_block(r2, r1) + return out class LocalNetFeatureExtractorBlock(nn.Module): @@ -301,12 +301,13 @@ def __init__( self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) - initializer_dict[kernel_initializer](self.conv_block.conv.weight) + initialize = initializer_dict[kernel_initializer] + initialize(self.conv_block.weight) def forward(self, x) -> torch.Tensor: """ Args: x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3]) """ - x = self.conv_block(x) - return x + out: torch.Tensor = self.conv_block(x) + return out diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 70ef1c0cc2..182d41b519 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -20,7 +20,7 @@ "out_kernel_initializer": ["zeros", "kaiming_normal"], "out_activation": ["sigmoid", None], } -TEST_CASE_LOCALNET_2D = [dict(zip(param_variations_2d.keys(), v)) for v in product(*param_variations_2d.values())] +TEST_CASE_LOCALNET_2D = [dict(zip(list(param_variations_2d.keys()), v)) for v in product(*list(param_variations_2d.values()))] TEST_CASE_LOCALNET_2D = [ [input_param, (1, input_param["in_channels"], 16, 16), (1, input_param["out_channels"], 16, 16)] for input_param in TEST_CASE_LOCALNET_2D @@ -36,7 +36,8 @@ "out_kernel_initializer": ["zeros", "kaiming_normal"], "out_activation": ["sigmoid", None], } -TEST_CASE_LOCALNET_3D = [dict(zip(param_variations_3d.keys(), v)) for v in product(*param_variations_3d.values())] +TEST_CASE_LOCALNET_3D = [dict(zip(list(param_variations_3d.keys()), v)) + for v in product(*list(param_variations_3d.values()))] TEST_CASE_LOCALNET_3D = [ [input_param, (1, input_param["in_channels"], 16, 16, 16), (1, input_param["out_channels"], 16, 16, 16)] for input_param in TEST_CASE_LOCALNET_3D diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index 4ea97b0c20..9e0dda05db 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -24,7 +24,8 @@ "act": ["sigmoid", None], "kernel_initializer": ["zeros", "kaiming_normal"], } -TEST_CASE_EXTRACT = [dict(zip(extract_param_option.keys(), v)) for v in product(*extract_param_option.values())] +TEST_CASE_EXTRACT = [dict(zip(list(extract_param_option.keys()), v)) + for v in product(*list(extract_param_option.values()))] TEST_CASE_EXTRACT = [[i] for i in TEST_CASE_EXTRACT] in_size = 4 From 7ef82c57c5f96e29683b7ad60f0c1502216dd1ed Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 13 Jan 2021 21:50:43 +0000 Subject: [PATCH 20/28] 1442 reformat Signed-off-by: kate-sann5100 --- tests/test_localnet.py | 9 ++++++--- tests/test_localnet_block.py | 5 +++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 182d41b519..4e0342d588 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -20,7 +20,9 @@ "out_kernel_initializer": ["zeros", "kaiming_normal"], "out_activation": ["sigmoid", None], } -TEST_CASE_LOCALNET_2D = [dict(zip(list(param_variations_2d.keys()), v)) for v in product(*list(param_variations_2d.values()))] +TEST_CASE_LOCALNET_2D = [ + dict(zip(list(param_variations_2d.keys()), v)) for v in product(*list(param_variations_2d.values())) +] TEST_CASE_LOCALNET_2D = [ [input_param, (1, input_param["in_channels"], 16, 16), (1, input_param["out_channels"], 16, 16)] for input_param in TEST_CASE_LOCALNET_2D @@ -36,8 +38,9 @@ "out_kernel_initializer": ["zeros", "kaiming_normal"], "out_activation": ["sigmoid", None], } -TEST_CASE_LOCALNET_3D = [dict(zip(list(param_variations_3d.keys()), v)) - for v in product(*list(param_variations_3d.values()))] +TEST_CASE_LOCALNET_3D = [ + dict(zip(list(param_variations_3d.keys()), v)) for v in product(*list(param_variations_3d.values())) +] TEST_CASE_LOCALNET_3D = [ [input_param, (1, input_param["in_channels"], 16, 16, 16), (1, input_param["out_channels"], 16, 16, 16)] for input_param in TEST_CASE_LOCALNET_3D diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index 9e0dda05db..5f465ca5d4 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -24,8 +24,9 @@ "act": ["sigmoid", None], "kernel_initializer": ["zeros", "kaiming_normal"], } -TEST_CASE_EXTRACT = [dict(zip(list(extract_param_option.keys()), v)) - for v in product(*list(extract_param_option.values()))] +TEST_CASE_EXTRACT = [ + dict(zip(list(extract_param_option.keys()), v)) for v in product(*list(extract_param_option.values())) +] TEST_CASE_EXTRACT = [[i] for i in TEST_CASE_EXTRACT] in_size = 4 From c5108069b5d1d9db0f76b2c6348fb8d688fb1d09 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 14 Jan 2021 00:29:50 +0000 Subject: [PATCH 21/28] 1442 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 11 +++++++-- tests/test_localnet.py | 33 ++++++++++++++++--------- tests/test_localnet_block.py | 20 ++++++++------- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index e1cffb3358..784e9d8d98 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -301,8 +301,15 @@ def __init__( self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) - initialize = initializer_dict[kernel_initializer] - initialize(self.conv_block.weight) + + if kernel_initializer == "zeros": + nn.init.zeros_(self.conv_block.conv.weight) + elif kernel_initializer == "kaiming_normal": + nn.init.kaiming_normal_(self.conv_block.conv.weight) + else: + raise ValueError( + f"unsupported kernel_initializer: {kernel_initializer}, currently supporting [zero, kaiming_normal]" + ) def forward(self, x) -> torch.Tensor: """ diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 4e0342d588..5191a71625 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -12,20 +12,30 @@ param_variations_2d = { - "spatial_dims": [2], - "in_channels": [2], - "out_channels": [2], - "num_channel_initial": [16], - "extract_levels": [[0, 1, 2]], + "spatial_dims": 2, + "in_channels": 2, + "out_channels": 2, + "num_channel_initial": 16, + "extract_levels": [0, 1, 2], "out_kernel_initializer": ["zeros", "kaiming_normal"], "out_activation": ["sigmoid", None], } + TEST_CASE_LOCALNET_2D = [ - dict(zip(list(param_variations_2d.keys()), v)) for v in product(*list(param_variations_2d.values())) -] -TEST_CASE_LOCALNET_2D = [ - [input_param, (1, input_param["in_channels"], 16, 16), (1, input_param["out_channels"], 16, 16)] - for input_param in TEST_CASE_LOCALNET_2D + [ + { + "spatial_dims": 2, + "in_channels": 2, + "out_channels": 2, + "num_channel_initial": 16, + "extract_levels": [0, 1, 2], + "out_kernel_initializer": init, + "out_activation": act, + }, + (1, 2, 16, 16), + (1, 2, 16, 16), + ] + for act, init in zip(["sigmoid", None], ["zeros", "kaiming_normal"]) ] @@ -39,7 +49,8 @@ "out_activation": ["sigmoid", None], } TEST_CASE_LOCALNET_3D = [ - dict(zip(list(param_variations_3d.keys()), v)) for v in product(*list(param_variations_3d.values())) + {k: v for k, v in zip(param_variations_3d.keys(), per)} + for per in product(*[iter(v) for v in param_variations_3d.values()]) ] TEST_CASE_LOCALNET_3D = [ [input_param, (1, input_param["in_channels"], 16, 16, 16), (1, input_param["out_channels"], 16, 16, 16)] diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index 5f465ca5d4..8b84edafdb 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -1,5 +1,6 @@ import unittest from itertools import product +from typing import Iterable import torch from parameterized import parameterized @@ -17,17 +18,18 @@ TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]] -extract_param_option = { - "spatial_dims": [2, 3], - "in_channels": [2], - "out_channels": [3], - "act": ["sigmoid", None], - "kernel_initializer": ["zeros", "kaiming_normal"], -} TEST_CASE_EXTRACT = [ - dict(zip(list(extract_param_option.keys()), v)) for v in product(*list(extract_param_option.values())) + [ + { + "spatial_dims": spatial_dims, + "in_channels": 2, + "out_channels": 3, + "act": act, + "kernel_initializer": kernel_initializer, + } + ] + for spatial_dims, act, kernel_initializer in zip([2, 3], ["sigmoid", None], ["zeros", "kaiming_normal"]) ] -TEST_CASE_EXTRACT = [[i] for i in TEST_CASE_EXTRACT] in_size = 4 From 61691e534b4a5f3d09b9df1f3548bc9075be9bc6 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 14 Jan 2021 00:35:41 +0000 Subject: [PATCH 22/28] 1442 reformat Signed-off-by: kate-sann5100 --- tests/test_localnet_block.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index 8b84edafdb..a07af0685f 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -1,6 +1,4 @@ import unittest -from itertools import product -from typing import Iterable import torch from parameterized import parameterized From 1399f680b7404e94a8e813abf4e76a46103ec613 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 14 Jan 2021 15:20:42 +0000 Subject: [PATCH 23/28] 1442 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/convolutions.py | 4 ++-- tests/test_localnet.py | 32 +++++++++++++++------------ tests/test_localnet_block.py | 5 +++++ 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py index 5e2dcf163c..a6fd9e017a 100644 --- a/monai/networks/blocks/convolutions.py +++ b/monai/networks/blocks/convolutions.py @@ -97,7 +97,7 @@ def __init__( if is_transposed: if output_padding is None: output_padding = stride_minus_kernel_padding(1, strides) - conv = conv_type( + conv: nn.Module = conv_type( in_channels, out_channels, kernel_size=kernel_size, @@ -109,7 +109,7 @@ def __init__( dilation=dilation, ) else: - conv = conv_type( + conv: nn.Module = conv_type( in_channels, out_channels, kernel_size=kernel_size, diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 5191a71625..808dcc8678 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -38,20 +38,24 @@ for act, init in zip(["sigmoid", None], ["zeros", "kaiming_normal"]) ] - -param_variations_3d = { - "spatial_dims": [3], - "in_channels": [2, 3], - "out_channels": [1, 3], - "num_channel_initial": [16, 32], - "extract_levels": [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]], - "out_kernel_initializer": ["zeros", "kaiming_normal"], - "out_activation": ["sigmoid", None], -} -TEST_CASE_LOCALNET_3D = [ - {k: v for k, v in zip(param_variations_3d.keys(), per)} - for per in product(*[iter(v) for v in param_variations_3d.values()]) -] +TEST_CASE_LOCALNET_3D = [] +for in_channels in [2, 3]: + for out_channels in [1, 3]: + for num_channel_initial in [4, 16, 32]: + for extract_levels in [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]: + for out_kernel_initializer in ["zeros", "kaiming_normal"]: + for out_activation in ["sigmoid", None]: + TEST_CASE_LOCALNET_3D.append( + { + "spatial_dims": 3, + "in_channels": in_channels, + "out_channels": out_channels, + "num_channel_initial": num_channel_initial, + "extract_levels": extract_levels, + "out_kernel_initializer": out_kernel_initializer, + "out_activation": out_activation, + } + ) TEST_CASE_LOCALNET_3D = [ [input_param, (1, input_param["in_channels"], 16, 16, 16), (1, input_param["out_channels"], 16, 16, 16)] for input_param in TEST_CASE_LOCALNET_3D diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index a07af0685f..6407a7aae6 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -94,6 +94,11 @@ def test_shape(self, input_param): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + def test_ill_arg(self): + # channel unmatch + with self.assertRaises(ValueError): + LocalNetFeatureExtractorBlock(spatial_dims=2, in_channels=2, out_channels=2, kernel_initializer="none") + if __name__ == "__main__": unittest.main() From 63290352c6358214fc6432b06df19402f570bacf Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 14 Jan 2021 15:26:21 +0000 Subject: [PATCH 24/28] 1442 reformat Signed-off-by: kate-sann5100 --- tests/test_localnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 808dcc8678..1854ed868b 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -1,5 +1,4 @@ import unittest -from itertools import product import torch from parameterized import parameterized From 6e8b60e88ae31bae465435aeb3c3d59c631eb18f Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 14 Jan 2021 15:41:16 +0000 Subject: [PATCH 25/28] 1442 reformat Signed-off-by: kate-sann5100 --- monai/networks/blocks/convolutions.py | 5 +++-- tests/test_localnet.py | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py index a6fd9e017a..7bfb3b47e4 100644 --- a/monai/networks/blocks/convolutions.py +++ b/monai/networks/blocks/convolutions.py @@ -94,10 +94,11 @@ def __init__( padding = same_padding(kernel_size, dilation) conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions] + conv: nn.Module if is_transposed: if output_padding is None: output_padding = stride_minus_kernel_padding(1, strides) - conv: nn.Module = conv_type( + conv = conv_type( in_channels, out_channels, kernel_size=kernel_size, @@ -109,7 +110,7 @@ def __init__( dilation=dilation, ) else: - conv: nn.Module = conv_type( + conv = conv_type( in_channels, out_channels, kernel_size=kernel_size, diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 1854ed868b..eb00a8f198 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -45,20 +45,20 @@ for out_kernel_initializer in ["zeros", "kaiming_normal"]: for out_activation in ["sigmoid", None]: TEST_CASE_LOCALNET_3D.append( - { - "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": out_channels, - "num_channel_initial": num_channel_initial, - "extract_levels": extract_levels, - "out_kernel_initializer": out_kernel_initializer, - "out_activation": out_activation, - } + [ + { + "spatial_dims": 3, + "in_channels": in_channels, + "out_channels": out_channels, + "num_channel_initial": num_channel_initial, + "extract_levels": extract_levels, + "out_kernel_initializer": out_kernel_initializer, + "out_activation": out_activation, + }, + (1, in_channels, 16, 16, 16), + (1, out_channels, 16, 16, 16) + ] ) -TEST_CASE_LOCALNET_3D = [ - [input_param, (1, input_param["in_channels"], 16, 16, 16), (1, input_param["out_channels"], 16, 16, 16)] - for input_param in TEST_CASE_LOCALNET_3D -] class TestDynUNet(unittest.TestCase): From c51d597eea7a414e627616654825f10563a6e361 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 14 Jan 2021 15:48:20 +0000 Subject: [PATCH 26/28] 1442 reformat Signed-off-by: kate-sann5100 --- tests/test_localnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_localnet.py b/tests/test_localnet.py index eb00a8f198..ab7c80ad99 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -56,7 +56,7 @@ "out_activation": out_activation, }, (1, in_channels, 16, 16, 16), - (1, out_channels, 16, 16, 16) + (1, out_channels, 16, 16, 16), ] ) From 2f18619a2f27b448d7325034e28376b9eeff204e Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 14 Jan 2021 16:51:21 +0000 Subject: [PATCH 27/28] 1442 remove initializsation Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 12 --------- monai/networks/nets/localnet.py | 3 --- tests/test_localnet.py | 36 +++++++++++-------------- tests/test_localnet_block.py | 8 +----- 4 files changed, 17 insertions(+), 42 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 784e9d8d98..17d93098a1 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -8,8 +8,6 @@ from monai.networks.layers import same_padding from monai.networks.layers.factories import batch_factory, maxpooling_factory -initializer_dict = {"zeros": nn.init.zeros_, "kaiming_normal": nn.init.kaiming_normal_} - def get_conv_block( spatial_dims: int, @@ -287,7 +285,6 @@ def __init__( in_channels: int, out_channels: int, act: Optional[Union[Tuple, str]] = "RELU", - kernel_initializer: str = "kaiming_normal", ) -> None: """ Args: @@ -302,15 +299,6 @@ def __init__( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) - if kernel_initializer == "zeros": - nn.init.zeros_(self.conv_block.conv.weight) - elif kernel_initializer == "kaiming_normal": - nn.init.kaiming_normal_(self.conv_block.conv.weight) - else: - raise ValueError( - f"unsupported kernel_initializer: {kernel_initializer}, currently supporting [zero, kaiming_normal]" - ) - def forward(self, x) -> torch.Tensor: """ Args: diff --git a/monai/networks/nets/localnet.py b/monai/networks/nets/localnet.py index 28573910eb..1bb3dcbc21 100644 --- a/monai/networks/nets/localnet.py +++ b/monai/networks/nets/localnet.py @@ -31,7 +31,6 @@ def __init__( out_channels: int, num_channel_initial: int, extract_levels: List[int], - out_kernel_initializer: str, out_activation: Optional[Union[Tuple, str]], ) -> None: """ @@ -41,7 +40,6 @@ def __init__( out_channels: number of output channels. num_channel_initial: number of initial channels, extract_levels: number of extraction levels, - out_kernel_initializer: initializer to use for kernels, out_activation: activation to use at end layer, """ super(LocalNet, self).__init__() @@ -86,7 +84,6 @@ def __init__( spatial_dims=spatial_dims, in_channels=num_channels[level], out_channels=out_channels, - kernel_initializer=out_kernel_initializer, act=out_activation, ) for level in self.extract_levels diff --git a/tests/test_localnet.py b/tests/test_localnet.py index ab7c80ad99..d4f812e811 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -16,7 +16,6 @@ "out_channels": 2, "num_channel_initial": 16, "extract_levels": [0, 1, 2], - "out_kernel_initializer": ["zeros", "kaiming_normal"], "out_activation": ["sigmoid", None], } @@ -28,13 +27,12 @@ "out_channels": 2, "num_channel_initial": 16, "extract_levels": [0, 1, 2], - "out_kernel_initializer": init, "out_activation": act, }, (1, 2, 16, 16), (1, 2, 16, 16), ] - for act, init in zip(["sigmoid", None], ["zeros", "kaiming_normal"]) + for act in ["sigmoid", None] ] TEST_CASE_LOCALNET_3D = [] @@ -42,23 +40,21 @@ for out_channels in [1, 3]: for num_channel_initial in [4, 16, 32]: for extract_levels in [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]: - for out_kernel_initializer in ["zeros", "kaiming_normal"]: - for out_activation in ["sigmoid", None]: - TEST_CASE_LOCALNET_3D.append( - [ - { - "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": out_channels, - "num_channel_initial": num_channel_initial, - "extract_levels": extract_levels, - "out_kernel_initializer": out_kernel_initializer, - "out_activation": out_activation, - }, - (1, in_channels, 16, 16, 16), - (1, out_channels, 16, 16, 16), - ] - ) + for out_activation in ["sigmoid", None]: + TEST_CASE_LOCALNET_3D.append( + [ + { + "spatial_dims": 3, + "in_channels": in_channels, + "out_channels": out_channels, + "num_channel_initial": num_channel_initial, + "extract_levels": extract_levels, + "out_activation": out_activation, + }, + (1, in_channels, 16, 16, 16), + (1, out_channels, 16, 16, 16), + ] + ) class TestDynUNet(unittest.TestCase): diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index 6407a7aae6..af5ef19222 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -23,10 +23,9 @@ "in_channels": 2, "out_channels": 3, "act": act, - "kernel_initializer": kernel_initializer, } ] - for spatial_dims, act, kernel_initializer in zip([2, 3], ["sigmoid", None], ["zeros", "kaiming_normal"]) + for spatial_dims, act in zip([2, 3], ["sigmoid", None]) ] in_size = 4 @@ -94,11 +93,6 @@ def test_shape(self, input_param): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - def test_ill_arg(self): - # channel unmatch - with self.assertRaises(ValueError): - LocalNetFeatureExtractorBlock(spatial_dims=2, in_channels=2, out_channels=2, kernel_initializer="none") - if __name__ == "__main__": unittest.main() From 786eb075b31ef8e9edde2e66816d49aa9695f724 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Fri, 15 Jan 2021 00:04:43 +0000 Subject: [PATCH 28/28] 1442 update factory calls Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 17d93098a1..ee7fac0690 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -6,7 +6,7 @@ from monai.networks.blocks import Convolution from monai.networks.layers import same_padding -from monai.networks.layers.factories import batch_factory, maxpooling_factory +from monai.networks.layers.factories import Norm, Pool def get_conv_block( @@ -90,7 +90,7 @@ def __init__( self.conv = get_conv_layer( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size ) - self.norm = batch_factory(spatial_dims)(out_channels) + self.norm = Norm[Norm.BATCH, spatial_dims](out_channels) self.relu = nn.ReLU() def forward(self, x) -> torch.Tensor: @@ -115,7 +115,7 @@ def __init__( in_channels=in_channels, out_channels=out_channels, ) - self.norm = batch_factory(spatial_dims)(out_channels) + self.norm = Norm[Norm.BATCH, spatial_dims](out_channels) self.relu = nn.ReLU() def forward(self, x, mid) -> torch.Tensor: @@ -158,7 +158,7 @@ def __init__( self.residual_block = ResidualBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size ) - self.max_pool = maxpooling_factory(spatial_dims)( + self.max_pool = Pool[Pool.MAX, spatial_dims]( kernel_size=2, )