diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index de16cd4cca..93a77d7b2a 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -23,7 +23,7 @@
 from monai.networks.blocks.encoder import BaseEncoder
 from monai.networks.layers.factories import Conv, Norm, Pool
-from monai.networks.layers.utils import get_pool_layer
+from monai.networks.layers.utils import get_act_layer, get_pool_layer
 from monai.utils import ensure_tuple_rep
 from monai.utils.module import look_up_option, optional_import
@@ -78,6 +78,7 @@ def __init__(
         spatial_dims: int = 3,
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
+        act: str | tuple = ("relu", {"inplace": True}),
     ) -> None:
         """
         Args:
@@ -86,6 +87,7 @@
             spatial_dims: number of spatial dimensions of the input image.
             stride: stride to use for first conv layer.
             downsample: which downsample layer to use.
+            act: activation type and arguments. Defaults to relu.
         """
         super().__init__()
@@ -94,7 +96,7 @@
 
         self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
         self.bn1 = norm_type(planes)
-        self.relu = nn.ReLU(inplace=True)
+        self.act = get_act_layer(name=act)
         self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
         self.bn2 = norm_type(planes)
         self.downsample = downsample
@@ -105,7 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         out: torch.Tensor = self.conv1(x)
         out = self.bn1(out)
-        out = self.relu(out)
+        out = self.act(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
@@ -114,7 +116,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             residual = self.downsample(x)
 
         out += residual
-        out = self.relu(out)
+        out = self.act(out)
 
         return out
@@ -129,6 +131,7 @@ def __init__(
         spatial_dims: int = 3,
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
+        act: str | tuple = ("relu", {"inplace": True}),
     ) -> None:
         """
         Args:
@@ -137,6 +140,7 @@
             spatial_dims: number of spatial dimensions of the input image.
             stride: stride to use for second conv layer.
             downsample: which downsample layer to use.
+            act: activation type and arguments. Defaults to relu.
         """
         super().__init__()
@@ -150,7 +154,7 @@
         self.bn2 = norm_type(planes)
         self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
         self.bn3 = norm_type(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
+        self.act = get_act_layer(name=act)
         self.downsample = downsample
         self.stride = stride
@@ -159,11 +163,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         out: torch.Tensor = self.conv1(x)
         out = self.bn1(out)
-        out = self.relu(out)
+        out = self.act(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
-        out = self.relu(out)
+        out = self.act(out)
 
         out = self.conv3(out)
         out = self.bn3(out)
@@ -172,7 +176,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             residual = self.downsample(x)
 
         out += residual
-        out = self.relu(out)
+        out = self.act(out)
 
         return out
@@ -202,6 +206,7 @@ class ResNet(nn.Module):
         num_classes: number of output (classifications).
         feed_forward: whether to add the FC layer for the output, default to `True`.
         bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
+        act: activation type and arguments. Defaults to relu.
     """
@@ -220,6 +225,7 @@ def __init__(
         num_classes: int = 400,
         feed_forward: bool = True,
         bias_downsample: bool = True,  # for backwards compatibility (also see PR #5477)
+        act: str | tuple = ("relu", {"inplace": True}),
     ) -> None:
         super().__init__()
@@ -257,7 +263,7 @@ def __init__(
             bias=False,
         )
         self.bn1 = norm_type(self.in_planes)
-        self.relu = nn.ReLU(inplace=True)
+        self.act = get_act_layer(name=act)
         self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
         self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=2)
@@ -329,7 +335,7 @@ def _make_layer(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv1(x)
         x = self.bn1(x)
-        x = self.relu(x)
+        x = self.act(x)
         if not self.no_max_pool:
             x = self.maxpool(x)
@@ -396,7 +402,7 @@ def forward(self, inputs: torch.Tensor):
         """
         x = self.conv1(inputs)
         x = self.bn1(x)
-        x = self.relu(x)
+        x = self.act(x)
 
         features = []
         features.append(x)
diff --git a/tests/test_resnet.py b/tests/test_resnet.py
index ecab133154..3a58d1c955 100644
--- a/tests/test_resnet.py
+++ b/tests/test_resnet.py
@@ -107,6 +107,7 @@
         "num_classes": 3,
         "conv1_t_size": [3],
         "conv1_t_stride": 1,
+        "act": ("relu", {"inplace": False}),
     },
     (1, 2, 32),
     (1, 3),
@@ -185,13 +186,29 @@
     (1, 3),
 ]
 
+TEST_CASE_8 = [
+    {
+        "block": "bottleneck",
+        "layers": [3, 4, 6, 3],
+        "block_inplanes": [64, 128, 256, 512],
+        "spatial_dims": 1,
+        "n_input_channels": 2,
+        "num_classes": 3,
+        "conv1_t_size": [3],
+        "conv1_t_stride": 1,
+        "act": ("relu", {"inplace": False}),
+    },
+    (1, 2, 32),
+    (1, 3),
+]
+
 TEST_CASES = []
 PRETRAINED_TEST_CASES = []
 for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
     for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
         TEST_CASES.append([model, *case])
         PRETRAINED_TEST_CASES.append([model, *case])
-for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
+for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]:
     TEST_CASES.append([ResNet, *case])
 
 TEST_SCRIPT_CASES = [