Skip to content

Commit

Permalink
Intensity augmentation transform (#3784)
Browse files Browse the repository at this point in the history
* intensity augmentation transforms

Signed-off-by: yc7620 <yaniel.cabrera20@imperial.ac.uk>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [MONAI] python code formatting

Signed-off-by: monai-bot <monai.miccai2019@gmail.com>

* mapping via bitbucket operation

Signed-off-by: yc7620 <yaniel.cabrera20@imperial.ac.uk>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [MONAI] python code formatting

Signed-off-by: monai-bot <monai.miccai2019@gmail.com>

Co-authored-by: yc7620 <yaniel.cabrera20@imperial.ac.uk>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
Co-authored-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
5 people authored Feb 22, 2022
1 parent fd63af4 commit 56205f5
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 0 deletions.
2 changes: 2 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
GaussianSmooth,
GibbsNoise,
HistogramNormalize,
IntensityRemap,
KSpaceSpikeNoise,
MaskIntensity,
NormalizeIntensity,
Expand All @@ -98,6 +99,7 @@
RandGaussianSmooth,
RandGibbsNoise,
RandHistogramShift,
RandIntensityRemap,
RandKSpaceSpikeNoise,
RandRicianNoise,
RandScaleIntensity,
Expand Down
112 changes: 112 additions & 0 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
"RandCoarseDropout",
"RandCoarseShuffle",
"HistogramNormalize",
"IntensityRemap",
"RandIntensityRemap",
]


Expand Down Expand Up @@ -2053,3 +2055,113 @@ def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None)
out, *_ = convert_to_dst_type(src=ret, dst=img, dtype=self.dtype or img.dtype)

return out


class IntensityRemap(RandomizableTransform):
"""
Transform for intensity remapping of images. The intensity at each
pixel is replaced by a new values coming from an intensity remappping
curve.
The remapping curve is created by uniformly sampling values from the
possible intensities for the input image and then adding a linear
component. The curve is the rescaled to the input image intensity range.
Intended to be used as a means to data augmentation via:
:py:class:`monai.transforms.RandIntensityRemap`.
Implementation is described in the work:
`Intensity augmentation for domain transfer of whole breast segmentation
in MRI <https://ieeexplore.ieee.org/abstract/document/9166708>`_.
Args:
kernel_size: window size for averaging operation for the remapping
curve.
slope: slope of the linear component. Easiest to leave default value
and tune the kernel_size parameter instead.
return_map: set to True for the transform to return a dictionary version
of the lookup table used in the intensity remapping. The keys
correspond to the old intensities, and the values are the new
values.
"""

def __init__(self, kernel_size: int = 30, slope: float = 0.7):

super().__init__()

self.kernel_size = kernel_size
self.slope = slope

def __call__(self, img: torch.Tensor) -> torch.Tensor:
"""
Args:
img: image to remap.
"""

img = img.clone()
# sample noise
vals_to_sample = torch.unique(img).tolist()
noise = torch.from_numpy(self.R.choice(vals_to_sample, len(vals_to_sample) - 1 + self.kernel_size))
# smooth
noise = torch.nn.AvgPool1d(self.kernel_size, stride=1)(noise.unsqueeze(0)).squeeze()
# add linear component
grid = torch.arange(len(noise)) / len(noise)
noise += self.slope * grid
# rescale
noise = (noise - noise.min()) / (noise.max() - noise.min()) * img.max() + img.min()

# intensity remapping function
index_img = torch.bucketize(img, torch.tensor(vals_to_sample))
img = noise[index_img]

return img


class RandIntensityRemap(RandomizableTransform):
"""
Transform for intensity remapping of images. The intensity at each
pixel is replaced by a new values coming from an intensity remappping
curve.
The remapping curve is created by uniformly sampling values from the
possible intensities for the input image and then adding a linear
component. The curve is the rescaled to the input image intensity range.
Implementation is described in the work:
`Intensity augmentation for domain transfer of whole breast segmentation
in MRI <https://ieeexplore.ieee.org/abstract/document/9166708>`_.
Args:
prob: probability of applying the transform.
kernel_size: window size for averaging operation for the remapping
curve.
slope: slope of the linear component. Easiest to leave default value
and tune the kernel_size parameter instead.
channel_wise: set to True to treat each channel independently.
"""

def __init__(self, prob: float = 0.1, kernel_size: int = 30, slope: float = 0.7, channel_wise: bool = True):

RandomizableTransform.__init__(self, prob=prob)
self.kernel_size = kernel_size
self.slope = slope
self.channel_wise = True

def __call__(self, img: torch.Tensor) -> torch.Tensor:
"""
Args:
img: image to remap.
"""
super().randomize(None)
if self._do_transform:
if self.channel_wise:
img = torch.stack(
[
IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img[i])
for i in range(len(img))
]
)
else:
img = IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img)

return img

0 comments on commit 56205f5

Please sign in to comment.