Modify dynunet forward function #1596

Merged · 2 commits · Feb 19, 2021
41 changes: 31 additions & 10 deletions monai/networks/nets/dynunet.py
@@ -14,6 +14,7 @@

import torch
import torch.nn as nn
from torch.nn.functional import interpolate

from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock

@@ -80,8 +81,22 @@ class DynUNet(nn.Module):
upsample_kernel_size: convolution kernel size for transposed convolution layers.
norm_name: [``"batch"``, ``"instance"``, ``"group"``]
feature normalization type and arguments.
deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
If ``True``, in training mode, the forward function will output not only the last feature
map, but also the previous feature maps that come from the intermediate up sample layers.
In order to unify the return type (a restriction of TorchScript), all intermediate
feature maps are interpolated into the same size as the last feature map and stacked together
(with a new dimension at dim 1) into one single tensor.
For instance, if there are three feature maps with shapes (1, 2, 32, 24), (1, 2, 16, 12) and
(1, 2, 8, 6), the last two will be interpolated into (1, 2, 32, 24), and the stacked tensor
will have the shape (1, 3, 2, 32, 24).
When calculating the loss, you can use torch.unbind to get all feature maps and compute the loss
one by one against the ground truth, then take a weighted average of the losses to get the final
loss (see the sketch below).
(To be added: a corresponding tutorial link)

deep_supr_num: number of feature maps that will be output from the deep supervision heads. The
value should be less than the number of up sample layers. Defaults to 1.
value should be larger than 0 and less than the number of up sample layers.
Defaults to 1.
res_block: whether to use residual connection based convolution blocks during the network.
Defaults to ``True``.
"""
@@ -95,6 +110,7 @@ def __init__(
strides: Sequence[Union[Sequence[int], int]],
upsample_kernel_size: Sequence[Union[Sequence[int], int]],
norm_name: str = "instance",
deep_supervision: bool = False,
deep_supr_num: int = 1,
res_block: bool = False,
):
@@ -113,6 +129,7 @@
self.bottleneck = self.get_bottleneck()
self.upsamples = self.get_upsamples()
self.output_block = self.get_output_block(0)
self.deep_supervision = deep_supervision
self.deep_supervision_heads = self.get_deep_supervision_heads()
self.deep_supr_num = deep_supr_num
self.apply(self.initialize_weights)
@@ -140,6 +157,8 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
return bottleneck
if index == 0: # don't associate a supervision head with self.input_block
current_head, rest_heads = nn.Identity(), superheads
elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one
current_head, rest_heads = nn.Identity(), superheads[1:]
else:
current_head, rest_heads = superheads[0], superheads[1:]

@@ -176,19 +195,21 @@ def check_kernel_stride(self):
def check_deep_supr_num(self):
deep_supr_num, strides = self.deep_supr_num, self.strides
num_up_layers = len(strides) - 1
if deep_supr_num < 1 or deep_supr_num >= num_up_layers:
if deep_supr_num >= num_up_layers:
raise AssertionError("deep_supr_num should be less than the number of up sample layers.")
if deep_supr_num < 1:
raise AssertionError("deep_supr_num should be larger than 0.")

def forward(self, x):
out = self.skip_layers(x)
return self.output_block(out)

def get_feature_maps(self):
"""
Return the feature maps.

"""
return self.heads[1 : self.deep_supr_num + 1]
out = self.output_block(out)
if self.training and self.deep_supervision:
out_all = [out]
feature_maps = self.heads[1 : self.deep_supr_num + 1]
for feature_map in feature_maps:
out_all.append(interpolate(feature_map, out.shape[2:]))
return torch.stack(out_all, dim=1)
return out

def get_input_block(self):
return self.conv_block(
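For context, a minimal usage sketch of the new forward behaviour (the network configuration below is illustrative only): in training mode with ``deep_supervision=True`` the output is the stacked tensor, while in eval mode only the last feature map is returned.

```python
import torch
from monai.networks.nets import DynUNet

net = DynUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    kernel_size=(3, 3, 3, 3),
    strides=(1, 2, 2, 2),
    upsample_kernel_size=(2, 2, 2),
    deep_supervision=True,
    deep_supr_num=2,
)

net.train()
out = net(torch.randn(1, 1, 64, 64))
print(out.shape)  # torch.Size([1, 3, 2, 64, 64]): last map + 2 supervision maps

net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 1, 64, 64))
print(out.shape)  # torch.Size([1, 2, 64, 64]): last feature map only
```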
19 changes: 7 additions & 12 deletions tests/test_dynunet.py
@@ -43,6 +43,7 @@
"strides": strides,
"upsample_kernel_size": strides[1:],
"norm_name": "batch",
"deep_supervision": False,
"res_block": res_block,
},
(1, in_channels, in_size, in_size),
@@ -65,6 +66,7 @@
"strides": ((1, 2, 1), 2, 2, 1),
"upsample_kernel_size": (2, 2, 1),
"norm_name": "instance",
"deep_supervision": False,
"res_block": res_block,
},
(1, in_channels, in_size, in_size, in_size),
@@ -77,6 +79,7 @@
for res_block in [True, False]:
for deep_supr_num in [1, 2]:
for strides in [(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 2)]:
scale = strides[0]
test_case = [
{
"spatial_dims": spatial_dims,
@@ -86,18 +89,13 @@
"strides": strides,
"upsample_kernel_size": strides[1:],
"norm_name": "group",
"deep_supervision": True,
"deep_supr_num": deep_supr_num,
"res_block": res_block,
},
(1, 1, *[in_size] * spatial_dims),
(1, 1 + deep_supr_num, 2, *[in_size // scale] * spatial_dims),
]
scale = 1
all_expected_shapes = []
for stride in strides[: 1 + deep_supr_num]:
scale *= stride
deep_out_shape = (1, 2, *[in_size // scale] * spatial_dims)
all_expected_shapes.append(deep_out_shape)
test_case.append(all_expected_shapes)
TEST_CASE_DEEP_SUPERVISION.append(test_case)


@@ -121,11 +119,8 @@ class TestDynUNetDeepSupervision(unittest.TestCase):
def test_shape(self, input_param, input_shape, expected_shape):
net = DynUNet(**input_param).to(device)
with torch.no_grad():
results = [net(torch.randn(input_shape).to(device))] + net.get_feature_maps()
self.assertEqual(len(results), len(expected_shape))
for idx in range(len(results)):
result, sub_expected_shape = results[idx], expected_shape[idx]
self.assertEqual(result.shape, sub_expected_shape)
results = net(torch.randn(input_shape).to(device))
self.assertEqual(results.shape, expected_shape)


if __name__ == "__main__":