Merge 69ff653 into 26f8446
rijobro authored Aug 25, 2021
2 parents 26f8446 + 69ff653 commit 9266581
Showing 15 changed files with 174 additions and 89 deletions.
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
@@ -517,3 +517,4 @@
weighted_patch_samples,
zero_margins,
)
from .utils_pytorch_numpy_unification import moveaxis
37 changes: 25 additions & 12 deletions monai/transforms/utility/array.py
@@ -31,6 +31,7 @@
map_binary_to_indices,
map_classes_to_indices,
)
from monai.transforms.utils_pytorch_numpy_unification import moveaxis
from monai.utils import (
convert_to_numpy,
convert_to_tensor,
@@ -82,17 +83,18 @@

class Identity(Transform):
"""
Convert the input to an np.ndarray, if input data is np.ndarray or subclasses, return unchanged data.
Do nothing to the data.
As the output value is the same as the input, it can be used as a testing tool to verify the transform chain,
Compose or transform adaptor, etc.
"""

def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
return np.asanyarray(img)
return img
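Each transform in this diff gains a class-level `backend` list and an `NdarrayOrTensor` signature; `Identity` is the simplest case, returning its input untouched. A minimal usage sketch of that pattern (assuming `TransformBackends` is importable from `monai.utils`, as the import block above suggests):

import numpy as np
import torch

from monai.transforms import Identity
from monai.utils import TransformBackends  # assumed export, per the surrounding imports

identity = Identity()

# the transform is now a pass-through for either backend
img_np = np.zeros((1, 2, 2))
img_pt = torch.zeros(1, 2, 2)
assert identity(img_np) is img_np
assert identity(img_pt) is img_pt

# the class-level `backend` list advertises which container types are supported
assert TransformBackends.TORCH in Identity.backend
assert TransformBackends.NUMPY in Identity.backend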


class AsChannelFirst(Transform):
@@ -111,16 +113,18 @@ class AsChannelFirst(Transform):
channel_dim: which dimension of input image is the channel, default is the last dimension.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, channel_dim: int = -1) -> None:
if not (isinstance(channel_dim, int) and channel_dim >= -1):
raise AssertionError("invalid channel dimension.")
self.channel_dim = channel_dim

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
return np.moveaxis(img, self.channel_dim, 0)
return moveaxis(img, self.channel_dim, 0)


class AsChannelLast(Transform):
@@ -138,16 +142,18 @@ class AsChannelLast(Transform):
channel_dim: which dimension of input image is the channel, default is the first dimension.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, channel_dim: int = 0) -> None:
if not (isinstance(channel_dim, int) and channel_dim >= -1):
raise AssertionError("invalid channel dimension.")
self.channel_dim = channel_dim

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
return np.moveaxis(img, self.channel_dim, -1)
return moveaxis(img, self.channel_dim, -1)


class AddChannel(Transform):
@@ -164,7 +170,9 @@ class AddChannel(Transform):
transforms.
"""

def __call__(self, img: NdarrayTensor):
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
@@ -179,14 +187,16 @@ class EnsureChannelFirst(Transform):
Convert the data to `channel_first` based on the `original_channel_dim` information.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, strict_check: bool = True):
"""
Args:
strict_check: whether to raise an error when the meta information is insufficient.
"""
self.strict_check = strict_check

def __call__(self, img: np.ndarray, meta_dict: Optional[Mapping] = None):
def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
@@ -220,16 +230,19 @@ class RepeatChannel(Transform):
repeats: the number of repetitions for each element.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, repeats: int) -> None:
if repeats <= 0:
raise AssertionError("repeats count must be greater than 0.")
self.repeats = repeats

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`, assuming `img` is a "channel-first" array.
"""
return np.repeat(img, self.repeats, 0)
repeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat
return repeat_fn(img, self.repeats, 0) # type: ignore
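The dispatch above picks `torch.repeat_interleave` for tensors and `np.repeat` for arrays; both repeat along dim 0, so a channel-first input gains repeated channels either way. A small standalone illustration of that equivalence (not using MONAI):

import numpy as np
import torch

img_np = np.arange(6).reshape(1, 2, 3)           # channel-first: one channel
img_pt = torch.as_tensor(img_np)

out_np = np.repeat(img_np, 3, 0)                 # -> shape (3, 2, 3)
out_pt = torch.repeat_interleave(img_pt, 3, 0)   # -> shape (3, 2, 3)

assert out_np.shape == tuple(out_pt.shape) == (3, 2, 3)
assert np.array_equal(out_np, out_pt.numpy())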


class RemoveRepeatedChannel(Transform):
26 changes: 18 additions & 8 deletions monai/transforms/utility/dictionary.py
@@ -169,6 +169,8 @@ class Identityd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.Identity`.
"""

backend = Identity.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
"""
Args:
@@ -180,9 +182,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
super().__init__(keys, allow_missing_keys)
self.identity = Identity()

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.identity(d[key])
@@ -194,6 +194,8 @@ class AsChannelFirstd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`.
"""

backend = AsChannelFirst.backend

def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None:
"""
Args:
@@ -205,7 +207,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_ke
super().__init__(keys, allow_missing_keys)
self.converter = AsChannelFirst(channel_dim=channel_dim)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
@@ -217,6 +219,8 @@ class AsChannelLastd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`.
"""

backend = AsChannelLast.backend

def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None:
"""
Args:
@@ -228,7 +232,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_key
super().__init__(keys, allow_missing_keys)
self.converter = AsChannelLast(channel_dim=channel_dim)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
@@ -240,6 +244,8 @@ class AddChanneld(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`.
"""

backend = AddChannel.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
"""
Args:
@@ -250,7 +256,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
super().__init__(keys, allow_missing_keys)
self.adder = AddChannel()

def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.adder(d[key])
@@ -262,6 +268,8 @@ class EnsureChannelFirstd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`.
"""

backend = EnsureChannelFirst.backend

def __init__(
self,
keys: KeysCollection,
@@ -289,7 +297,7 @@ def __init__(
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))

def __call__(self, data) -> Dict[Hashable, np.ndarray]:
def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, meta_key, meta_key_postfix in zip(self.keys, self.meta_keys, self.meta_key_postfix):
d[key] = self.adjuster(d[key], d[meta_key or f"{key}_{meta_key_postfix}"])
@@ -301,6 +309,8 @@ class RepeatChanneld(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`.
"""

backend = RepeatChannel.backend

def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None:
"""
Args:
@@ -312,7 +322,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool
super().__init__(keys, allow_missing_keys)
self.repeater = RepeatChannel(repeats)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.repeater(d[key])
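Because each dictionary wrapper now forwards its array transform's `backend` list and simply delegates per key, a data dictionary can mix numpy and torch values freely. A hedged usage sketch (key names and shapes are illustrative):

import numpy as np
import torch

from monai.transforms import AddChanneld, RepeatChanneld

data = {
    "img": torch.zeros(2, 3),   # torch value
    "seg": np.zeros((2, 3)),    # numpy value
}

data = AddChanneld(keys=["img", "seg"])(data)                 # both become (1, 2, 3)
data = RepeatChanneld(keys=["img", "seg"], repeats=2)(data)   # both become (2, 2, 3)

assert tuple(data["img"].shape) == data["seg"].shape == (2, 2, 3)
assert isinstance(data["img"], torch.Tensor) and isinstance(data["seg"], np.ndarray)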
35 changes: 35 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
@@ -0,0 +1,35 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch

from monai.config.type_definitions import NdarrayOrTensor

__all__ = [
"moveaxis",
]


def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor:
if isinstance(x, torch.Tensor):
if hasattr(torch, "moveaxis"):
return torch.moveaxis(x, src, dst)
else:
# fallback for pytorch < 1.8.0, where torch.moveaxis is unavailable:
# get original indices, remove desired index and insert it in new position
indices = list(range(x.ndim))
indices.pop(src)
indices.insert(dst, src)
return x.permute(indices)
elif isinstance(x, np.ndarray):
return np.moveaxis(x, src, dst)
raise RuntimeError()
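A quick sanity check of the helper above, exercising the numpy branch and the torch branch (on PyTorch < 1.8 the permute fallback runs instead of `torch.moveaxis`); a sketch, not part of the commit:

import numpy as np
import torch

from monai.transforms.utils_pytorch_numpy_unification import moveaxis

x_np = np.random.rand(1, 2, 3, 4)
x_pt = torch.as_tensor(x_np)

# move the last axis to the front for both backends
y_np = moveaxis(x_np, -1, 0)
y_pt = moveaxis(x_pt, -1, 0)

assert y_np.shape == (4, 1, 2, 3)
assert tuple(y_pt.shape) == (4, 1, 2, 3)
np.testing.assert_allclose(y_np, y_pt.numpy())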
17 changes: 11 additions & 6 deletions tests/test_add_channeld.py
@@ -15,16 +15,21 @@
from parameterized import parameterized

from monai.transforms import AddChanneld
from tests.utils import TEST_NDARRAYS

TEST_CASE_1 = [
{"keys": ["img", "seg"]},
{"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[0, 1], [1, 2]])},
(1, 2, 2),
]
TESTS = []
for p in TEST_NDARRAYS:
TESTS.append(
[
{"keys": ["img", "seg"]},
{"img": p(np.array([[0, 1], [1, 2]])), "seg": p(np.array([[0, 1], [1, 2]]))},
(1, 2, 2),
]
)


class TestAddChanneld(unittest.TestCase):
@parameterized.expand([TEST_CASE_1])
@parameterized.expand(TESTS)
def test_shape(self, input_param, input_data, expected_shape):
result = AddChanneld(**input_param)(input_data)
self.assertEqual(result["img"].shape, expected_shape)
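`TEST_NDARRAYS` comes from `tests/utils`; judging by how the loop uses it, it is a tuple of callables that wrap the test data in each supported container type. A hypothetical reconstruction, for orientation only (the real definition may differ, e.g. in how it handles CUDA):

from functools import partial
from typing import Callable, Tuple

import numpy as np
import torch

# hypothetical stand-in for tests/utils.TEST_NDARRAYS:
# each entry converts the raw test data into one backend's container type
TEST_NDARRAYS: Tuple[Callable, ...] = (np.array, torch.as_tensor)
if torch.cuda.is_available():
    TEST_NDARRAYS = TEST_NDARRAYS + (partial(torch.as_tensor, device="cuda"),)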
22 changes: 14 additions & 8 deletions tests/test_as_channel_first.py
@@ -12,23 +12,29 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import AsChannelFirst
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASE_1 = [{"channel_dim": -1}, (4, 1, 2, 3)]

TEST_CASE_2 = [{"channel_dim": 3}, (4, 1, 2, 3)]

TEST_CASE_3 = [{"channel_dim": 2}, (3, 1, 2, 4)]
TESTS = []
for p in TEST_NDARRAYS:
TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)])
TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)])
TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)])


class TestAsChannelFirst(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, input_param, expected_shape):
test_data = np.random.randint(0, 2, size=[1, 2, 3, 4])
@parameterized.expand(TESTS)
def test_value(self, in_type, input_param, expected_shape):
test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4]))
result = AsChannelFirst(**input_param)(test_data)
self.assertTupleEqual(result.shape, expected_shape)
if isinstance(test_data, torch.Tensor):
test_data = test_data.cpu().numpy()
expected = np.moveaxis(test_data, input_param["channel_dim"], 0)
assert_allclose(expected, result)
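`assert_allclose` is likewise imported from `tests/utils`; since `expected` is numpy while `result` may be a tensor, it presumably normalises both sides before comparing. A hypothetical stand-in with that behaviour, not the file's actual contents:

import numpy as np
import torch

def assert_allclose(expected, actual, *args, **kwargs):
    """Hypothetical helper: compare arrays/tensors regardless of backend."""
    expected = expected.cpu().numpy() if isinstance(expected, torch.Tensor) else expected
    actual = actual.cpu().numpy() if isinstance(actual, torch.Tensor) else actual
    np.testing.assert_allclose(expected, actual, *args, **kwargs)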


if __name__ == "__main__":
21 changes: 11 additions & 10 deletions tests/test_as_channel_firstd.py
@@ -15,21 +15,22 @@
from parameterized import parameterized

from monai.transforms import AsChannelFirstd
from tests.utils import TEST_NDARRAYS

TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]

TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]

TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]
TESTS = []
for p in TEST_NDARRAYS:
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)])
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)])
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)])


class TestAsChannelFirstd(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, input_param, expected_shape):
@parameterized.expand(TESTS)
def test_shape(self, in_type, input_param, expected_shape):
test_data = {
"image": np.random.randint(0, 2, size=[1, 2, 3, 4]),
"label": np.random.randint(0, 2, size=[1, 2, 3, 4]),
"extra": np.random.randint(0, 2, size=[1, 2, 3, 4]),
"image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
"label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
"extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
}
result = AsChannelFirstd(**input_param)(test_data)
self.assertTupleEqual(result["image"].shape, expected_shape)