diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 6e61db07ca..d62722478e 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -510,7 +510,7 @@ def _resnet( # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) if shortcut_type == kwargs.get("shortcut_type", "B") and ( - bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True + bias_downsample == kwargs.get("bias_downsample", True) ): # Download the MedicalNet pretrained model model_state_dict = get_pretrained_resnet_medicalnet( @@ -518,8 +518,7 @@ def _resnet( ) else: raise NotImplementedError( - f"Please set shortcut_type to {shortcut_type} and bias_downsample to" - f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}" + f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bias_downsample} " f"when using pretrained MedicalNet resnet{resnet_depth}" ) else: @@ -681,7 +680,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int): # After testing # False: 10, 50, 101, 152, 200 # Any: 18, 34 - bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 + bias_downsample = resnet_depth in (18, 34) shortcut_type = "A" if resnet_depth in [18, 34] else "B" return bias_downsample, shortcut_type diff --git a/tests/test_resnet.py b/tests/test_resnet.py index e873f1238a..a55d18f5de 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -266,7 +266,7 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): @parameterized.expand(PRETRAINED_TEST_CASES) @skip_if_quick @skip_if_no_cuda - def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape): + def test_resnet_pretrained(self, model, input_param, _input_shape, _expected_shape): net = model(**input_param).to(device) # Save ckpt torch.save(net.state_dict(), self.tmp_ckpt_filename) @@ -290,9 +290,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape and input_param.get("n_input_channels", 3) == 1 and input_param.get("feed_forward", True) is False and input_param.get("shortcut_type", "B") == shortcut_type - and ( - input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True - ) + and (input_param.get("bias_downsample", True) == bias_downsample) ): model(**cp_input_param) else: @@ -303,7 +301,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape cp_input_param["n_input_channels"] = 1 cp_input_param["feed_forward"] = False cp_input_param["shortcut_type"] = shortcut_type - cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True + cp_input_param["bias_downsample"] = bias_downsample if cp_input_param.get("spatial_dims", 3) == 3: with skip_if_downloading_fails(): pretrained_net = model(**cp_input_param).to(device)