diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index ba88c35f8d..7d0b3bff79 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn +from torch.nn.functional import interpolate from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock @@ -80,8 +81,22 @@ class DynUNet(nn.Module): upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: [``"batch"``, ``"instance"``, ``"group"``] feature normalization type and arguments. + deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. + If ``True``, in training mode, the forward function will output not only the last feature + map, but also the previous feature maps that come from the intermediate up sample layers. + In order to unify the return type (the restriction of TorchScript), all intermediate + feature maps are interpolated into the same size as the last feature map and stacked together + (with a new dimension at index 1) into one single tensor. + For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and + (1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor + will have the shape (1, 3, 2, 32, 24). + When calculating the loss, you can use torch.unbind (with dim=1) to get all feature maps and compute the loss + one by one with the ground truth, then do a weighted average for all losses to achieve the final loss. + (To be added: a corresponding tutorial link) + deep_supr_num: number of feature maps that will output during deep supervision head. The - value should be less than the number of up sample layers. Defaults to 1. + value should be larger than 0 and less than the number of up sample layers. + Defaults to 1. res_block: whether to use residual connection based convolution blocks during the network. Defaults to ``True``. 
""" @@ -95,6 +110,7 @@ def __init__( strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], norm_name: str = "instance", + deep_supervision: bool = False, deep_supr_num: int = 1, res_block: bool = False, ): @@ -113,6 +129,7 @@ def __init__( self.bottleneck = self.get_bottleneck() self.upsamples = self.get_upsamples() self.output_block = self.get_output_block(0) + self.deep_supervision = deep_supervision self.deep_supervision_heads = self.get_deep_supervision_heads() self.deep_supr_num = deep_supr_num self.apply(self.initialize_weights) @@ -140,6 +157,8 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): return bottleneck if index == 0: # don't associate a supervision head with self.input_block current_head, rest_heads = nn.Identity(), superheads + elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one + current_head, rest_heads = nn.Identity(), superheads[1:] else: current_head, rest_heads = superheads[0], superheads[1:] @@ -176,19 +195,21 @@ def check_kernel_stride(self): def check_deep_supr_num(self): deep_supr_num, strides = self.deep_supr_num, self.strides num_up_layers = len(strides) - 1 - if deep_supr_num < 1 or deep_supr_num >= num_up_layers: + if deep_supr_num >= num_up_layers: raise AssertionError("deep_supr_num should be less than the number of up sample layers.") + if deep_supr_num < 1: + raise AssertionError("deep_supr_num should be larger than 0.") def forward(self, x): out = self.skip_layers(x) - return self.output_block(out) - - def get_feature_maps(self): - """ - Return the feature maps. 
- - """ - return self.heads[1 : self.deep_supr_num + 1] + out = self.output_block(out) + if self.training and self.deep_supervision: + out_all = [out] + feature_maps = self.heads[1 : self.deep_supr_num + 1] + for feature_map in feature_maps: + out_all.append(interpolate(feature_map, out.shape[2:])) + return torch.stack(out_all, dim=1) + return out def get_input_block(self): return self.conv_block( diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index d72c1fc48a..05e0c17465 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -43,6 +43,7 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "batch", + "deep_supervision": False, "res_block": res_block, }, (1, in_channels, in_size, in_size), @@ -65,6 +66,7 @@ "strides": ((1, 2, 1), 2, 2, 1), "upsample_kernel_size": (2, 2, 1), "norm_name": "instance", + "deep_supervision": False, "res_block": res_block, }, (1, in_channels, in_size, in_size, in_size), @@ -77,6 +79,7 @@ for res_block in [True, False]: for deep_supr_num in [1, 2]: for strides in [(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 2)]: + scale = strides[0] test_case = [ { "spatial_dims": spatial_dims, @@ -86,18 +89,13 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "group", + "deep_supervision": True, "deep_supr_num": deep_supr_num, "res_block": res_block, }, (1, 1, *[in_size] * spatial_dims), + (1, 1 + deep_supr_num, 2, *[in_size // scale] * spatial_dims), ] - scale = 1 - all_expected_shapes = [] - for stride in strides[: 1 + deep_supr_num]: - scale *= stride - deep_out_shape = (1, 2, *[in_size // scale] * spatial_dims) - all_expected_shapes.append(deep_out_shape) - test_case.append(all_expected_shapes) TEST_CASE_DEEP_SUPERVISION.append(test_case) @@ -121,11 +119,8 @@ class TestDynUNetDeepSupervision(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): net = DynUNet(**input_param).to(device) with torch.no_grad(): - results = 
[net(torch.randn(input_shape).to(device))] + net.get_feature_maps() - self.assertEqual(len(results), len(expected_shape)) - for idx in range(len(results)): - result, sub_expected_shape = results[idx], expected_shape[idx] - self.assertEqual(result.shape, sub_expected_shape) + results = net(torch.randn(input_shape).to(device)) + self.assertEqual(results.shape, expected_shape) if __name__ == "__main__":