Skip to content

Commit

Permalink
Merge branch 'master' into 1533-fix-classificationsaver
Browse files Browse the repository at this point in the history
  • Loading branch information
Nic-Ma authored Feb 2, 2021
2 parents 46b3328 + 97ff1e0 commit b813c73
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 28 deletions.
5 changes: 2 additions & 3 deletions monai/networks/nets/ahnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,7 @@ def __init__(
self.dense4 = DenseBlock(spatial_dims, ndenselayer, num_init_features, densebn, densegrowth, 0.0)
noutdense4 = num_init_features + densegrowth * ndenselayer

if psp_block_num > 0:
self.psp = PSP(spatial_dims, psp_block_num, noutdense4, upsample_mode)
self.psp = PSP(spatial_dims, psp_block_num, noutdense4, upsample_mode)
self.final = Final(spatial_dims, psp_block_num + noutdense4, out_channels, upsample_mode)

# Initialise parameters
Expand Down Expand Up @@ -511,7 +510,7 @@ def forward(self, x):

sum4 = self.up3(d3) + conv_x
d4 = self.dense4(sum4)
if self.psp_block_num > 0 and self.psp is not None:
if self.psp_block_num > 0:
psp = self.psp(d4)
x = torch.cat((psp, d4), dim=1)
else:
Expand Down
18 changes: 10 additions & 8 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

from collections.abc import Iterable
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -35,7 +35,7 @@
ShiftIntensity,
ThresholdIntensity,
)
from monai.utils import dtype_torch_to_numpy, ensure_tuple_size
from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size

__all__ = [
"RandGaussianNoised",
Expand Down Expand Up @@ -110,27 +110,29 @@ def __init__(
) -> None:
super().__init__(keys)
self.prob = prob
self.mean = ensure_tuple_size(mean, len(self.keys))
self.mean = ensure_tuple_rep(mean, len(self.keys))
self.std = std
self._do_transform = False
self._noise: Optional[np.ndarray] = None
self._noise: List[np.ndarray] = []

def randomize(self, im_shape: Sequence[int]) -> None:
self._do_transform = self.R.random() < self.prob
self._noise = self.R.normal(self.mean, self.R.uniform(0, self.std), size=im_shape)
self._noise.clear()
for m in self.mean:
self._noise.append(self.R.normal(m, self.R.uniform(0, self.std), size=im_shape))

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)

image_shape = d[self.keys[0]].shape # image shape from the first data key
self.randomize(image_shape)
if self._noise is None:
if len(self._noise) != len(self.keys):
raise AssertionError
if not self._do_transform:
return d
for key in self.keys:
for noise, key in zip(self._noise, self.keys):
dtype = dtype_torch_to_numpy(d[key].dtype) if isinstance(d[key], torch.Tensor) else d[key].dtype
d[key] = d[key] + self._noise.astype(dtype)
d[key] = d[key] + noise.astype(dtype)
return d


Expand Down
32 changes: 15 additions & 17 deletions tests/test_rand_gaussian_noised.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,36 @@
from monai.transforms import RandGaussianNoised
from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D

TEST_CASE_0 = ["test_zero_mean", ["img"], 0, 0.1]
TEST_CASE_1 = ["test_non_zero_mean", ["img"], 1, 0.5]
TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1]
TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5]
TEST_CASES = [TEST_CASE_0, TEST_CASE_1]

seed = 0

# Test with numpy

def test_numpy_or_torch(keys, mean, std, imt):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std)
gaussian_fn.set_random_state(seed)
noised = gaussian_fn({k: imt for k in keys})
np.random.seed(seed)
np.random.random()
for k in keys:
expected = imt + np.random.normal(mean, np.random.uniform(0, std), size=imt.shape)
np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5)


# Test with numpy
class TestRandGaussianNoisedNumpy(NumpyImageTestCase2D):
@parameterized.expand(TEST_CASES)
def test_correct_results(self, _, keys, mean, std):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std)
gaussian_fn.set_random_state(seed)
noised = gaussian_fn({"img": self.imt})
np.random.seed(seed)
np.random.random()
expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
np.testing.assert_allclose(expected, noised["img"], atol=1e-5, rtol=1e-5)
test_numpy_or_torch(keys, mean, std, self.imt)


# Test with torch
class TestRandGaussianNoisedTorch(TorchImageTestCase2D):
@parameterized.expand(TEST_CASES)
def test_correct_results(self, _, keys, mean, std):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std)
gaussian_fn.set_random_state(seed)
noised = gaussian_fn({"img": self.imt})
np.random.seed(seed)
np.random.random()
expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
np.testing.assert_allclose(expected, noised["img"], atol=1e-5, rtol=1e-5)
test_numpy_or_torch(keys, mean, std, self.imt)


if __name__ == "__main__":
Expand Down

0 comments on commit b813c73

Please sign in to comment.