Skip to content

3517 Refine AddCoordinateChannels transform #3524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -725,13 +725,6 @@ Spatial
:members:
:special-members: __call__

`AddCoordinateChannels`
"""""""""""""""""""""""
.. autoclass:: AddCoordinateChannels
:members:
:special-members: __call__


Smooth Field
^^^^^^^^^^^^

Expand Down Expand Up @@ -935,6 +928,12 @@ Utility
:members:
:special-members: __call__

`AddCoordinateChannels`
"""""""""""""""""""""""
.. autoclass:: AddCoordinateChannels
:members:
:special-members: __call__


Dictionary Transforms
---------------------
Expand Down Expand Up @@ -1519,12 +1518,6 @@ Spatial (Dict)
:members:
:special-members: __call__

`AddCoordinateChannelsd`
""""""""""""""""""""""""
.. autoclass:: AddCoordinateChannelsd
:members:
:special-members: __call__

`GridDistortiond`
"""""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GridDistortiond.png
Expand Down Expand Up @@ -1767,6 +1760,12 @@ Utility (Dict)
:members:
:special-members: __call__

`AddCoordinateChannelsd`
""""""""""""""""""""""""
.. autoclass:: AddCoordinateChannelsd
:members:
:special-members: __call__

Transform Adaptors
------------------
.. automodule:: monai.transforms.adaptors
Expand Down
8 changes: 4 additions & 4 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@
from .smooth_field.array import RandSmoothFieldAdjustContrast, RandSmoothFieldAdjustIntensity, SmoothField
from .smooth_field.dictionary import RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd
from .spatial.array import (
AddCoordinateChannels,
Affine,
AffineGrid,
Flip,
Expand All @@ -305,9 +304,6 @@
Zoom,
)
from .spatial.dictionary import (
AddCoordinateChannelsd,
AddCoordinateChannelsD,
AddCoordinateChannelsDict,
Affined,
AffineD,
AffineDict,
Expand Down Expand Up @@ -366,6 +362,7 @@
from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
from .utility.array import (
AddChannel,
AddCoordinateChannels,
AddExtremePointsChannel,
AsChannelFirst,
AsChannelLast,
Expand Down Expand Up @@ -401,6 +398,9 @@
AddChanneld,
AddChannelD,
AddChannelDict,
AddCoordinateChannelsd,
AddCoordinateChannelsD,
AddCoordinateChannelsDict,
AddExtremePointsChanneld,
AddExtremePointsChannelD,
AddExtremePointsChannelDict,
Expand Down
45 changes: 0 additions & 45 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
create_translate,
map_spatial_axes,
)
from monai.transforms.utils_pytorch_numpy_unification import concatenate
from monai.utils import (
GridSampleMode,
GridSamplePadMode,
Expand Down Expand Up @@ -77,7 +76,6 @@
"RandAffine",
"Rand2DElastic",
"Rand3DElastic",
"AddCoordinateChannels",
]

RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]]
Expand Down Expand Up @@ -2024,49 +2022,6 @@ def __call__(
return out


class AddCoordinateChannels(Transform):
"""
Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling,
to allow feeding of the patch's location into the network.

This can be seen as a input-only version of CoordConv:

Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, spatial_channels: Sequence[int]) -> None:
"""
Args:
spatial_channels: the spatial dimensions that are to have their coordinates encoded in a channel and
appended to the input. E.g., `(1,2,3)` will append three channels to the input, encoding the
coordinates of the input's three spatial dimensions (0 is reserved for the channel dimension).
"""
self.spatial_channels = spatial_channels

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: data to be transformed, assuming `img` is channel first.
"""
if max(self.spatial_channels) > img.ndim - 1:
raise ValueError(
f"input has {img.ndim-1} spatial dimensions, cannot add AddCoordinateChannels channel for "
f"dim {max(self.spatial_channels)}."
)
if 0 in self.spatial_channels:
raise ValueError("cannot add AddCoordinateChannels channel for dimension 0, as 0 is channel dim.")

spatial_dims = img.shape[1:]
coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_dims), indexing="ij"))
coord_channels, *_ = convert_to_dst_type(coord_channels, img) # type: ignore
# only keep required dimensions. need to subtract 1 since im will be 0-based
# but user input is 1-based (because channel dim is 0)
coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]]
return concatenate((img, coord_channels), axis=0)


class GridDistortion(Transform):

backend = [TransformBackends.TORCH]
Expand Down
33 changes: 0 additions & 33 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
AddCoordinateChannels,
Affine,
AffineGrid,
Flip,
Expand Down Expand Up @@ -123,9 +122,6 @@
"ZoomDict",
"RandZoomD",
"RandZoomDict",
"AddCoordinateChannelsd",
"AddCoordinateChannelsD",
"AddCoordinateChannelsDict",
]

GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
Expand Down Expand Up @@ -1756,34 +1752,6 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
return d


class AddCoordinateChannelsd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.AddCoordinateChannels`.
"""

backend = AddCoordinateChannels.backend

def __init__(self, keys: KeysCollection, spatial_channels: Sequence[int], allow_missing_keys: bool = False) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
allow_missing_keys: don't raise exception if key is missing.
spatial_channels: the spatial dimensions that are to have their coordinates encoded in a channel and
appended to the input. E.g., `(1,2,3)` will append three channels to the input, encoding the
coordinates of the input's three spatial dimensions. It is assumed dimension 0 is the channel.

"""
super().__init__(keys, allow_missing_keys)
self.add_coordinate_channels = AddCoordinateChannels(spatial_channels)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.add_coordinate_channels(d[key])
return d


class GridDistortiond(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.GridDistortion`.
Expand Down Expand Up @@ -1919,4 +1887,3 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
RandRotateD = RandRotateDict = RandRotated
ZoomD = ZoomDict = Zoomd
RandZoomD = RandZoomDict = RandZoomd
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
44 changes: 44 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
convert_to_cupy,
convert_to_numpy,
convert_to_tensor,
deprecated_arg,
ensure_tuple,
look_up_option,
min_version,
Expand All @@ -56,6 +57,7 @@
"AsChannelFirst",
"AsChannelLast",
"AddChannel",
"AddCoordinateChannels",
"EnsureChannelFirst",
"EnsureType",
"RepeatChannel",
Expand Down Expand Up @@ -1254,3 +1256,45 @@ def __call__(self, data):
if not self._do_transform:
return data
return super().__call__(data)


class AddCoordinateChannels(Transform):
"""
Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling,
to allow feeding of the patch's location into the network.

This can be seen as a input-only version of CoordConv:

Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018.

Args:
spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and
appended to the input image. E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels
to the input image, encoding the coordinates of the input's three spatial dimensions.

.. deprecated:: 0.8.0
``spatial_channels`` is deprecated, use ``spatial_dims`` instead.

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

@deprecated_arg(
name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead."
)
def __init__(self, spatial_dims: Sequence[int]) -> None:
self.spatial_dims = spatial_dims

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: data to be transformed, assuming `img` is channel first.
"""
if max(self.spatial_dims) > img.ndim - 2 or min(self.spatial_dims) < 0:
raise ValueError(f"`spatial_dims` values must be within [0, {img.ndim - 2}]")

spatial_size = img.shape[1:]
coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_size), indexing="ij"))
coord_channels, *_ = convert_to_dst_type(coord_channels, img) # type: ignore
coord_channels = coord_channels[list(self.spatial_dims)]
return concatenate((img, coord_channels), axis=0)
40 changes: 39 additions & 1 deletion monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utility.array import (
AddChannel,
AddCoordinateChannels,
AddExtremePointsChannel,
AsChannelFirst,
AsChannelLast,
Expand Down Expand Up @@ -61,14 +62,17 @@
)
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
from monai.transforms.utils_pytorch_numpy_unification import concatenate
from monai.utils import convert_to_numpy, ensure_tuple, ensure_tuple_rep
from monai.utils import convert_to_numpy, deprecated_arg, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import TraceKeys, TransformBackends
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
"AddChannelD",
"AddChannelDict",
"AddChanneld",
"AddCoordinateChannelsD",
"AddCoordinateChannelsDict",
"AddCoordinateChannelsd",
"AddExtremePointsChannelD",
"AddExtremePointsChannelDict",
"AddExtremePointsChanneld",
Expand Down Expand Up @@ -1589,6 +1593,39 @@ def __call__(self, data):
return super().__call__(data)


class AddCoordinateChannelsd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.AddCoordinateChannels`.

Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and
appended to the input image. E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels
to the input image, encoding the coordinates of the input's three spatial dimensions.
allow_missing_keys: don't raise exception if key is missing.

.. deprecated:: 0.8.0
``spatial_channels`` is deprecated, use ``spatial_dims`` instead.

"""

backend = AddCoordinateChannels.backend

@deprecated_arg(
name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead."
)
def __init__(self, keys: KeysCollection, spatial_dims: Sequence[int], allow_missing_keys: bool = False) -> None:
super().__init__(keys, allow_missing_keys)
self.add_coordinate_channels = AddCoordinateChannels(spatial_dims=spatial_dims)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.add_coordinate_channels(d[key])
return d


IdentityD = IdentityDict = Identityd
AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
AsChannelLastD = AsChannelLastDict = AsChannelLastd
Expand Down Expand Up @@ -1627,3 +1664,4 @@ def __call__(self, data):
ToDeviceD = ToDeviceDict = ToDeviced
CuCIMD = CuCIMDict = CuCIMd
RandCuCIMD = RandCuCIMDict = RandCuCIMd
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
8 changes: 4 additions & 4 deletions tests/test_add_coordinate_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], []
for p in TEST_NDARRAYS:
TESTS.append([{"spatial_channels": (1, 2, 3)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (4, 3, 3, 3)])
TESTS.append([{"spatial_channels": (1,)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (2, 3, 3, 3)])
TEST_CASES_ERROR_1.append([{"spatial_channels": (3,)}, p(np.random.randint(0, 2, size=(1, 3, 3)))])
TEST_CASES_ERROR_2.append([{"spatial_channels": (0, 1, 2)}, p(np.random.randint(0, 2, size=(1, 3, 3)))])
TESTS.append([{"spatial_dims": (0, 1, 2)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (4, 3, 3, 3)])
TESTS.append([{"spatial_dims": (0,)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (2, 3, 3, 3)])
TEST_CASES_ERROR_1.append([{"spatial_dims": (2,)}, p(np.random.randint(0, 2, size=(1, 3, 3)))])
TEST_CASES_ERROR_2.append([{"spatial_dims": (-1, 0, 1)}, p(np.random.randint(0, 2, size=(1, 3, 3)))])


class TestAddCoordinateChannels(unittest.TestCase):
Expand Down
12 changes: 4 additions & 8 deletions tests/test_add_coordinate_channelsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,20 @@
for p in TEST_NDARRAYS:
TESTS.append(
[
{"spatial_channels": (1, 2, 3), "keys": ["img"]},
{"spatial_dims": (0, 1, 2), "keys": ["img"]},
{"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))},
(4, 3, 3, 3),
]
)
TESTS.append(
[
{"spatial_channels": (1,), "keys": ["img"]},
{"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))},
(2, 3, 3, 3),
]
[{"spatial_dims": (0,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, (2, 3, 3, 3)]
)

TEST_CASES_ERROR_1.append(
[{"spatial_channels": (3,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}]
[{"spatial_dims": (2,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}]
)
TEST_CASES_ERROR_2.append(
[{"spatial_channels": (0, 1, 2), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}]
[{"spatial_dims": (-1, 0, 1), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}]
)


Expand Down