Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2679 Add IntensityStats transform to record intensity statistics into meta data #2685

Merged
merged 21 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,13 @@ Utility
:members:
:special-members: __call__

`IntensityStats`
""""""""""""""""
.. autoclass:: IntensityStats
:members:
:special-members: __call__


Dictionary Transforms
---------------------

Expand Down Expand Up @@ -911,6 +918,7 @@ Intensity (Dict)
:members:
:special-members: __call__


IO (Dict)
^^^^^^^^^

Expand Down Expand Up @@ -1265,6 +1273,13 @@ Utility (Dict)
:members:
:special-members: __call__

`IntensityStatsd`
"""""""""""""""""
.. autoclass:: IntensityStatsd
:members:
:special-members: __call__


Transform Adaptors
------------------
.. automodule:: monai.transforms.adaptors
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@
EnsureType,
FgBgToIndices,
Identity,
IntensityStats,
LabelToMask,
Lambda,
MapLabelValue,
Expand Down Expand Up @@ -390,6 +391,9 @@
Identityd,
IdentityD,
IdentityDict,
IntensityStatsd,
IntensityStatsD,
IntensityStatsDict,
LabelToMaskd,
LabelToMaskD,
LabelToMaskDict,
Expand Down
18 changes: 13 additions & 5 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,13 @@ class ShiftIntensity(Transform):
def __init__(self, offset: float) -> None:
    # Constant shift applied to every intensity value when `__call__`
    # is not given an explicit per-call override.
    self.offset = offset

def __call__(self, img: np.ndarray, offset: Optional[float] = None) -> np.ndarray:
    """
    Apply the intensity shift to `img`.

    Args:
        img: input image to shift intensity.
        offset: optional runtime offset; if None, fall back to the `offset`
            provided at construction time. Allows callers (e.g. the dict
            wrappers) to scale the shift by image-specific statistics.

    """
    offset = self.offset if offset is None else offset
    # cast back to the input dtype so the shift never silently promotes the image
    return np.asarray((img + offset), dtype=img.dtype)


class RandShiftIntensity(RandomizableTransform):
Expand All @@ -214,20 +216,26 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1
raise AssertionError("offsets should be a number or pair of numbers.")
self.offsets = (min(offsets), max(offsets))
self._offset = self.offsets[0]
self._shfiter = ShiftIntensity(self._offset)

def randomize(self, data: Optional[Any] = None) -> None:
    """Sample this call's random offset, then let the base class decide applicability."""
    low, high = self.offsets
    self._offset = self.R.uniform(low=low, high=high)
    # base-class randomize draws `_do_transform`; keep it after the uniform draw
    # so the RNG consumption order (and thus seeded reproducibility) is unchanged
    super().randomize(None)

def __call__(self, img: np.ndarray, factor: Optional[float] = None) -> np.ndarray:
    """
    Apply the transform to `img`.

    Args:
        img: input image to shift intensity.
        factor: a factor to multiply the random offset, then shift.
            can be some image specific value at runtime, like: max(img), etc.

    """
    self.randomize()
    if not self._do_transform:
        return img
    # NOTE(review): `_shfiter` is a typo of `_shifter`; it is the attribute name
    # set in `__init__`, so it is kept here — renaming only this side would break lookup.
    return self._shfiter(img, self._offset if factor is None else self._offset * factor)


class StdShiftIntensity(Transform):
Expand Down
77 changes: 68 additions & 9 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
ThresholdIntensity,
)
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple
from monai.utils import dtype_torch_to_numpy, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple

__all__ = [
"RandGaussianNoised",
Expand Down Expand Up @@ -232,21 +232,53 @@ class ShiftIntensityd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.ShiftIntensity`.
"""

def __init__(
    self,
    keys: KeysCollection,
    offset: float,
    factor_key: Optional[str] = None,
    meta_keys: Optional[KeysCollection] = None,
    meta_key_postfix: str = "meta_dict",
    allow_missing_keys: bool = False,
) -> None:
    """
    Args:
        keys: keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        offset: offset value to shift the intensity of image.
        factor_key: if not None, use it as the key to extract a value from the corresponding
            meta data dictionary of `key` at runtime, and multiply the `offset` to shift intensity.
            Usually, `IntensityStatsd` transform can pre-compute statistics of intensity values
            and store in the meta data.
            it also can be a sequence of strings, map to `keys`.
        meta_keys: explicitly indicate the key of the corresponding meta data dictionary.
            used to extract the factor value if `factor_key` is not None.
            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
            the meta data is a dictionary object which contains: filename, original_shape, etc.
            it can be a sequence of string, map to the `keys`.
            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
        meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according
            to the key data, default is `meta_dict`, the meta data is a dictionary object.
            used to extract the factor value if `factor_key` is not None.
        allow_missing_keys: don't raise exception if key is missing.
    """
    super().__init__(keys, allow_missing_keys)
    # broadcast `factor_key` so every key gets its own (possibly shared) lookup key
    self.factor_key = ensure_tuple_rep(factor_key, len(self.keys))
    self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
    if len(self.keys) != len(self.meta_keys):
        raise ValueError("meta_keys should have the same length as keys.")
    self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
    self.shifter = ShiftIntensity(offset)

def __call__(self, data) -> Dict[Hashable, np.ndarray]:
    """
    Shift the intensity of every image under `self.keys`, optionally scaling the
    constructor `offset` by a per-image factor read from the meta data dictionary.
    """
    d = dict(data)
    for key, factor_key, meta_key, meta_key_postfix in self.key_iterator(
        d, self.factor_key, self.meta_keys, self.meta_key_postfix
    ):
        meta_key = meta_key or f"{key}_{meta_key_postfix}"
        # factor is None when there is no meta dict for this key or the factor
        # key is absent; then the shifter falls back to its constructor offset.
        factor: Optional[float] = d[meta_key].get(factor_key) if meta_key in d else None
        offset = None if factor is None else self.shifter.offset * factor
        d[key] = self.shifter(d[key], offset=offset)
    return d


Expand All @@ -259,6 +291,9 @@ def __init__(
self,
keys: KeysCollection,
offsets: Union[Tuple[float, float], float],
factor_key: Optional[str] = None,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
prob: float = 0.1,
allow_missing_keys: bool = False,
) -> None:
Expand All @@ -268,6 +303,20 @@ def __init__(
See also: :py:class:`monai.transforms.compose.MapTransform`
offsets: offset range to randomly shift.
if single number, offset value is picked from (-offsets, offsets).
factor_key: if not None, use it as the key to extract a value from the corresponding
meta data dictionary of `key` at runtime, and multiply the random `offset` to shift intensity.
Usually, `IntensityStatsd` transform can pre-compute statistics of intensity values
and store in the meta data.
it also can be a sequence of strings, map to `keys`.
meta_keys: explicitly indicate the key of the corresponding meta data dictionary.
used to extract the factor value if `factor_key` is not None.
for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
the meta data is a dictionary object which contains: filename, original_shape, etc.
it can be a sequence of string, map to the `keys`.
if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according
to the key data, default is `meta_dict`, the meta data is a dictionary object.
used to extract the factor value if `factor_key` is not None.
prob: probability of shifting the intensity.
(Default 0.1, with 10% probability it returns a shifted array.)
allow_missing_keys: don't raise exception if key is missing.
Expand All @@ -282,19 +331,29 @@ def __init__(
raise AssertionError("offsets should be a number or pair of numbers.")
self.offsets = (min(offsets), max(offsets))
self._offset = self.offsets[0]
self.factor_key = ensure_tuple_rep(factor_key, len(self.keys))
self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
if len(self.keys) != len(self.meta_keys):
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.shifter = ShiftIntensity(self._offset)

def randomize(self, data: Optional[Any] = None) -> None:
    """Sample this call's random offset, then let the base class decide applicability."""
    low, high = self.offsets
    self._offset = self.R.uniform(low=low, high=high)
    # base-class randomize draws `_do_transform`; keep it after the uniform draw
    # so the RNG consumption order (and thus seeded reproducibility) is unchanged
    super().randomize(None)

def __call__(self, data) -> Dict[Hashable, np.ndarray]:
    """
    Randomly shift the intensity of every image under `self.keys`. One random
    offset is sampled per call and shared by all keys, optionally scaled by a
    per-image factor read from the meta data dictionary.
    """
    d = dict(data)
    self.randomize()
    if not self._do_transform:
        return d
    for key, factor_key, meta_key, meta_key_postfix in self.key_iterator(
        d, self.factor_key, self.meta_keys, self.meta_key_postfix
    ):
        meta_key = meta_key or f"{key}_{meta_key_postfix}"
        # factor is None when there is no meta dict for this key or the factor
        # key is absent; then the raw random offset is used unscaled.
        factor: Optional[float] = d[meta_key].get(factor_key) if meta_key in d else None
        offset = self._offset if factor is None else self._offset * factor
        d[key] = self.shifter(d[key], offset=offset)
    return d


Expand Down
78 changes: 76 additions & 2 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import sys
import time
import warnings
from typing import Callable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -32,7 +32,7 @@
map_binary_to_indices,
map_classes_to_indices,
)
from monai.utils import ensure_tuple, issequenceiterable, min_version, optional_import
from monai.utils import ensure_tuple, issequenceiterable, look_up_option, min_version, optional_import

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
Expand Down Expand Up @@ -66,6 +66,7 @@
"AddExtremePointsChannel",
"TorchVision",
"MapLabelValue",
"IntensityStats",
]


Expand Down Expand Up @@ -938,3 +939,76 @@ def __call__(self, img: np.ndarray):
np.place(out_flat, img_flat == o, t)

return out_flat.reshape(img.shape)


class IntensityStats(Transform):
    """
    Compute statistics for the intensity values of input image and store into the meta data dictionary.
    For example: if `ops=[lambda x: np.mean(x), "max"]` and `key_prefix="orig"`, may generate below stats:
    `{"orig_custom_0": 1.5, "orig_max": 3.0}`.

    Args:
        ops: expected operations to compute statistics for the intensity.
            if a string, will map to the predefined operations, supported: ["mean", "median", "max", "min", "std"]
            mapping to `np.nanmean`, `np.nanmedian`, `np.nanmax`, `np.nanmin`, `np.nanstd`.
            if a callable function, will execute the function on input image.
        key_prefix: the prefix to combine with `ops` name to generate the key to store the results in the
            meta data dictionary. if some `ops` are callable functions, will use "{key_prefix}_custom_{index}"
            as the key, where index counts from 0.
        channel_wise: whether to compute statistics for every channel of input image separately.
            the first dimension of `img` is treated as the channel dimension.
            if True, return a list of values for every operation, default to False.

    """

    def __init__(self, ops: Sequence[Union[str, Callable]], key_prefix: str, channel_wise: bool = False) -> None:
        # NaN-aware reducers so missing values don't poison the statistics
        self.supported_ops = {
            "mean": np.nanmean,
            "median": np.nanmedian,
            "max": np.nanmax,
            "min": np.nanmin,
            "std": np.nanstd,
        }
        # validate string ops eagerly so misconfiguration fails at construction time
        self.ops = [o if callable(o) else look_up_option(o, self.supported_ops.keys()) for o in ensure_tuple(ops)]
        self.key_prefix = key_prefix
        self.channel_wise = channel_wise

    def __call__(
        self,
        img: np.ndarray,
        meta_data: Optional[Dict] = None,
        mask: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, Dict]:
        """
        Compute statistics for the intensity of input image.

        Args:
            img: input image to compute intensity stats.
            meta_data: meta data dictionary to store the statistics data, if None, will create an empty dictionary.
            mask: if not None, mask the image to extract only the interested area to compute statistics.
                mask must be a bool array with the same shape as input `img`.

        Returns:
            the unchanged `img` and the meta data dictionary with the statistics stored under
            `"{key_prefix}_{op}"` (string ops) or `"{key_prefix}_custom_{index}"` (callable ops).

        """
        if meta_data is None:
            meta_data = {}

        if mask is not None:
            if mask.shape != img.shape or mask.dtype != bool:
                raise TypeError("mask must be bool array with the same shape as input `img`.")
            if self.channel_wise:
                # mask each channel separately so per-channel statistics stay meaningful;
                # a plain `img[mask]` would flatten away the channel dimension and
                # `_compute` would then iterate scalars instead of channels.
                img_ = [c[m] for c, m in zip(img, mask)]
            else:
                img_ = img[mask]
        else:
            img_ = img

        def _compute(op: Callable, data):
            # per-channel list of scalars when channel_wise, else a single scalar
            if self.channel_wise:
                return [op(c) for c in data]
            return op(data)

        custom_index = 0
        for o in self.ops:
            if isinstance(o, str):
                meta_data[f"{self.key_prefix}_{o}"] = _compute(self.supported_ops[o], img_)
            elif callable(o):
                meta_data[f"{self.key_prefix}_custom_{custom_index}"] = _compute(o, img_)
                custom_index += 1

        return img, meta_data
Loading