From bbb8008d024cf786c5f6d406ec82931d87ffb9f5 Mon Sep 17 00:00:00 2001 From: Andres Date: Mon, 13 Sep 2021 21:20:05 +0100 Subject: [PATCH 1/5] Add dropout arg in DynUNet init Signed-off-by: Andres --- monai/networks/blocks/dynunet_block.py | 19 +++++++++++++++++-- monai/networks/blocks/dynunet_block_v1.py | 10 ++++++++++ monai/networks/nets/dynunet.py | 13 ++++++++----- monai/networks/nets/dynunet_v1.py | 3 +++ 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index bb654d841c..dd713c5149 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -33,6 +33,7 @@ class UnetResBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + dropout: dropout probability """ @@ -44,6 +45,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + dropout: Optional[Union[Tuple, str, float]] = None, ): super(UnetResBlock, self).__init__() self.conv1 = get_conv_layer( @@ -52,6 +54,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( @@ -60,6 +63,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, conv_only=True, ) self.conv3 = get_conv_layer( @@ -68,6 +72,7 @@ def __init__( out_channels, kernel_size=1, stride=stride, + dropout=dropout, conv_only=True, ) self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) @@ -107,6 +112,7 @@ class UnetBasicBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + dropout: dropout probability """ @@ -118,6 +124,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + dropout: Optional[Union[Tuple, str, float]] = None, ): super(UnetBasicBlock, self).__init__() self.conv1 = get_conv_layer( @@ -126,6 +133,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( @@ -134,6 +142,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, conv_only=True, ) self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) @@ -164,6 +173,7 @@ class UnetUpBlock(nn.Module): stride: convolution stride. upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: feature normalization type and arguments. + dropout: dropout probability """ @@ -176,6 +186,7 @@ def __init__( stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: Union[Tuple, str], + dropout: Optional[Union[Tuple, str, float]] = None, ): super(UnetUpBlock, self).__init__() upsample_stride = upsample_kernel_size @@ -185,6 +196,7 @@ def __init__( out_channels, kernel_size=upsample_kernel_size, stride=upsample_stride, + dropout=dropout, conv_only=True, is_transposed=True, ) @@ -194,6 +206,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, norm_name=norm_name, ) @@ -206,10 +219,10 @@ def forward(self, inp, skip): class UnetOutBlock(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, dropout: float = 0.0): super(UnetOutBlock, self).__init__() self.conv = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, bias=True, conv_only=True + spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True ) def forward(self, inp): @@ -224,6 +237,7 @@ def get_conv_layer( stride: Union[Sequence[int], int] = 1, act: Optional[Union[Tuple, str]] = Act.PRELU, norm: Union[Tuple, str] = Norm.INSTANCE, + dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = False, conv_only: bool = True, is_transposed: bool = False, @@ -240,6 +254,7 @@ def get_conv_layer( kernel_size=kernel_size, act=act, norm=norm, + dropout=dropout, bias=bias, conv_only=conv_only, is_transposed=is_transposed, diff --git a/monai/networks/blocks/dynunet_block_v1.py b/monai/networks/blocks/dynunet_block_v1.py index d5d9bbf3dc..b5b88dd0df 100644 --- a/monai/networks/blocks/dynunet_block_v1.py +++ b/monai/networks/blocks/dynunet_block_v1.py @@ -32,6 +32,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: str, + dropout: float = 0.0, ): nn.Module.__init__(self) self.conv1 = get_conv_layer( @@ -40,6 +41,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( @@ -48,6 +50,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, conv_only=True, ) self.conv3 = get_conv_layer( @@ -56,6 +59,7 @@ def __init__( out_channels, kernel_size=1, stride=stride, + dropout=dropout, conv_only=True, ) self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) @@ -81,6 +85,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: str, + dropout: float = 0.0, ): nn.Module.__init__(self) self.conv1 = get_conv_layer( @@ -89,6 +94,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( @@ -97,6 +103,7 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, conv_only=True, ) self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) @@ -118,6 +125,7 @@ def __init__( stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: str, + dropout: float = 0.0, ): nn.Module.__init__(self) upsample_stride = upsample_kernel_size @@ -127,6 +135,7 @@ def __init__( out_channels, kernel_size=upsample_kernel_size, stride=upsample_stride, + dropout=dropout, conv_only=True, is_transposed=True, ) @@ -137,6 +146,7 @@ def __init__( kernel_size=kernel_size, stride=1, norm_name=norm_name, + dropout=dropout, ) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 4af70b22c7..94511aaf62 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -86,6 +86,7 @@ class DynUNet(nn.Module): strides: convolution strides for each blocks. upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should equal to strides[1:]. + dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. If ``True``, in training mode, the forward function will output not only the last feature @@ -115,6 +116,7 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], + dropout: float = 0.0, norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), deep_supervision: bool = False, deep_supr_num: int = 1, @@ -128,6 +130,7 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name + self.dropout = dropout self.conv_block = UnetResBlock if res_block else UnetBasicBlock self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() @@ -225,6 +228,7 @@ def get_input_block(self): self.kernel_size[0], self.strides[0], self.norm_name, + dropout=self.dropout, ) def get_bottleneck(self): @@ -235,14 +239,11 @@ def get_bottleneck(self): self.kernel_size[-1], self.strides[-1], self.norm_name, + dropout=self.dropout, ) def get_output_block(self, idx: int): - return UnetOutBlock( - self.spatial_dims, - self.filters[idx], - self.out_channels, - ) + return UnetOutBlock(self.spatial_dims, self.filters[idx], self.out_channels, dropout=self.dropout) def get_downsamples(self): inp, out = self.filters[:-2], self.filters[1:-1] @@ -276,6 +277,7 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "dropout": self.dropout, "upsample_kernel_size": up_kernel, } layer = conv_block(**params) @@ -289,6 +291,7 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "dropout": self.dropout, } layer = conv_block(**params) layers.append(layer) diff --git a/monai/networks/nets/dynunet_v1.py b/monai/networks/nets/dynunet_v1.py index feb05d1762..c6a54807e4 100644 --- a/monai/networks/nets/dynunet_v1.py +++ b/monai/networks/nets/dynunet_v1.py @@ -38,6 +38,7 @@ class DynUNetV1(DynUNet): kernel_size: convolution kernel size. strides: convolution strides for each blocks. upsample_kernel_size: convolution kernel size for transposed convolution layers. + dropout: dropout ratio. Defaults to no dropout. norm_name: [``"batch"``, ``"instance"``, ``"group"``]. Defaults to "instance". deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. deep_supr_num: number of feature maps that will output during deep supervision head. Defaults to 1. @@ -57,6 +58,7 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], + dropout: float = 0.0, norm_name: str = "instance", deep_supervision: bool = False, deep_supr_num: int = 1, @@ -70,6 +72,7 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name + self.dropout = dropout self.conv_block = _UnetResBlockV1 if res_block else _UnetBasicBlockV1 # type: ignore self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() From 73b8f7edebfd31ef18460e64dcfb0b1cb809cfa0 Mon Sep 17 00:00:00 2001 From: Andres Date: Mon, 13 Sep 2021 22:31:48 +0100 Subject: [PATCH 2/5] Change to new API - dropout arg Signed-off-by: Andres --- monai/networks/nets/dynunet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 94511aaf62..9f58de355e 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -116,7 +116,7 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], - dropout: float = 0.0, + dropout: Optional[Union[Tuple, str, float]] = None, norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), deep_supervision: bool = False, deep_supr_num: int = 1, From c71abe6838a0ae856bd5ce45acc685b45290f13e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Sep 2021 22:40:48 +0100 Subject: [PATCH 3/5] add test cases Signed-off-by: Wenqi Li --- monai/networks/nets/dynunet.py | 2 +- tests/test_dynunet.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 9f58de355e..d65cd9f5f4 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -187,7 +187,7 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." - if not (len(kernels) == len(strides) and len(kernels) >= 3): + if len(kernels) != len(strides) or len(kernels) < 3: raise AssertionError(error_msg) for idx, k_i in enumerate(kernels): diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 81ed239461..18fe146a40 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -26,14 +26,14 @@ expected_shape: Sequence[Any] TEST_CASE_DYNUNET_2D = [] +out_channels = 2 +in_size = 64 +spatial_dims = 2 for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: + expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) for in_channels in [2, 3]: for res_block in [True, False]: - out_channels = 2 - in_size = 64 - spatial_dims = 2 - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) test_case = [ { "spatial_dims": spatial_dims, @@ -45,6 +45,7 @@ "norm_name": "batch", "deep_supervision": False, "res_block": res_block, + "dropout": None, }, (1, in_channels, in_size, in_size), expected_shape, @@ -52,11 +53,11 @@ TEST_CASE_DYNUNET_2D.append(test_case) TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides +in_channels = 1 +in_size = 64 for out_channels in [2, 3]: + expected_shape = (1, out_channels, 64, 32, 64) for res_block in [True, False]: - in_channels = 1 - in_size = 64 - expected_shape = (1, out_channels, 64, 32, 64) test_case = [ { "spatial_dims": 3, @@ -68,6 +69,7 @@ "norm_name": ("INSTANCE", {"affine": True}), "deep_supervision": False, "res_block": res_block, + "dropout": ("alphadropout", {"p": 0.25}), }, (1, in_channels, in_size, in_size, in_size), expected_shape, From 536432fb8ed5084ce9e9bf14910384be4af3b3f2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 13 Sep 2021 22:34:28 +0100 Subject: [PATCH 4/5] ThresholdIntensity, ThresholdIntensityd (#2944) * ThresholdIntensity, ThresholdIntensityd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 2 +- monai/transforms/intensity/array.py | 12 +++-- monai/transforms/intensity/dictionary.py | 4 +- .../utils_pytorch_numpy_unification.py | 15 ++++++ tests/test_threshold_intensity.py | 19 +++---- tests/test_threshold_intensityd.py | 52 +++++++++++-------- 6 files changed, 68 insertions(+), 36 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index a07dee867b..9575a412b4 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -524,4 +524,4 @@ weighted_patch_samples, zero_margins, ) -from .utils_pytorch_numpy_unification import in1d, moveaxis +from .utils_pytorch_numpy_unification import in1d, moveaxis, where diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 311534aa8b..b6fa1f72b7 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -28,6 +28,7 @@ from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array +from monai.transforms.utils_pytorch_numpy_unification import where from monai.utils import ( PT_BEFORE_1_7, InvalidPyTorchVersionError, @@ -655,6 +656,8 @@ class ThresholdIntensity(Transform): cval: value to fill the remaining parts of the image, default is 0. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None: if not isinstance(threshold, (int, float)): raise ValueError("threshold must be a float or int number.") @@ -662,13 +665,14 @@ def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> N self.above = above self.cval = cval - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.asarray( - np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype - ) + mask = img > self.threshold if self.above else img < self.threshold + res = where(mask, img, self.cval) + res, *_ = convert_data_type(res, dtype=img.dtype) + return res class ScaleIntensityRange(Transform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 22b1edd5fd..cccf3e2a90 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -664,6 +664,8 @@ class ThresholdIntensityd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = ThresholdIntensity.backend + def __init__( self, keys: KeysCollection, @@ -675,7 +677,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.filter(d[key]) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 2eebe3eda3..70ecb2848d 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -17,6 +17,7 @@ __all__ = [ "moveaxis", "in1d", + "where", ] @@ -50,3 +51,17 @@ def in1d(x, y): if isinstance(x, np.ndarray): return np.in1d(x, y) return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) + + +def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor: + """ + Note that `torch.where` may convert y.dtype to x.dtype. + """ + result: NdarrayOrTensor + if isinstance(condition, np.ndarray): + result = np.where(condition, x, y) + else: + x = torch.as_tensor(x, device=condition.device) + y = torch.as_tensor(y, device=condition.device, dtype=x.dtype) + result = torch.where(condition, x, y) + return result diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index a6d3895709..0614514456 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -15,20 +15,21 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensity +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)] - -TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)] - -TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]) + TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]) + TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]) class TestThresholdIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = np.arange(10) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = in_type(np.arange(10)) result = ThresholdIntensity(**input_param)(test_data) - np.testing.assert_allclose(result, expected_value) + assert_allclose(result, expected_value) if __name__ == "__main__": diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index efcfcfe604..398f9cfe91 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -15,31 +15,41 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensityd - -TEST_CASE_1 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, - (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), -] - -TEST_CASE_2 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, - (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), -] - -TEST_CASE_3 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, - (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, + (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, + (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, + (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), + ] + ) class TestThresholdIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = {"image": np.arange(10), "label": np.arange(10), "extra": np.arange(10)} + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))} result = ThresholdIntensityd(**input_param)(test_data) - np.testing.assert_allclose(result["image"], expected_value) - np.testing.assert_allclose(result["label"], expected_value) - np.testing.assert_allclose(result["extra"], expected_value) + assert_allclose(result["image"], expected_value) + assert_allclose(result["label"], expected_value) + assert_allclose(result["extra"], expected_value) if __name__ == "__main__": From 6967929a2be2c2d97a2034b58ad4c055d67873a2 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Sep 2021 22:46:41 +0100 Subject: [PATCH 5/5] fixes typing Signed-off-by: Wenqi Li --- monai/networks/blocks/dynunet_block.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index dd713c5149..fc37fc8999 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -219,7 +219,9 @@ def forward(self, inp, skip): class UnetOutBlock(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, dropout: float = 0.0): + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None + ): super(UnetOutBlock, self).__init__() self.conv = get_conv_layer( spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True