Skip to content

Commit ec72d31

Browse files
authored
Merge branch 'dev' into 4569-loading-rng
2 parents 811c455 + 8087dc5 commit ec72d31

File tree

7 files changed

+304
-0
lines changed

7 files changed

+304
-0
lines changed

docs/source/transforms.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,12 @@ Post-processing
578578
.. autoclass:: ProbNMS
579579
:members:
580580

581+
`SobelGradients`
582+
""""""""""""""""
583+
.. autoclass:: SobelGradients
584+
:members:
585+
:special-members: __call__
586+
581587
`VoteEnsemble`
582588
""""""""""""""
583589
.. autoclass:: VoteEnsemble
@@ -1593,6 +1599,14 @@ Post-processing (Dict)
15931599
:members:
15941600
:special-members: __call__
15951601

1602+
1603+
`SobelGradientsd`
1604+
"""""""""""""""""
1605+
.. autoclass:: SobelGradientsd
1606+
:members:
1607+
:special-members: __call__
1608+
1609+
15961610
Spatial (Dict)
15971611
^^^^^^^^^^^^^^
15981612

monai/networks/layers/simplelayers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso
290290
else:
291291
# even-sized kernels are not supported
292292
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
293+
elif kwargs["padding"] == "same" and not pytorch_after(1, 10):
294+
# even-sized kernels are not supported
295+
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
293296

294297
if "stride" not in kwargs:
295298
kwargs["stride"] = 1

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@
268268
MeanEnsemble,
269269
ProbNMS,
270270
RemoveSmallObjects,
271+
SobelGradients,
271272
VoteEnsemble,
272273
)
273274
from .post.dictionary import (
@@ -307,6 +308,9 @@
307308
SaveClassificationD,
308309
SaveClassificationd,
309310
SaveClassificationDict,
311+
SobelGradientsd,
312+
SobelGradientsD,
313+
SobelGradientsDict,
310314
VoteEnsembleD,
311315
VoteEnsembled,
312316
VoteEnsembleDict,

monai/transforms/post/array.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"LabelToContour",
5555
"MeanEnsemble",
5656
"ProbNMS",
57+
"SobelGradients",
5758
"VoteEnsemble",
5859
"Invert",
5960
]
@@ -852,3 +853,53 @@ def __call__(self, data):
852853
inverted = self.transform.inverse(data)
853854
inverted = self.post_func(inverted.to(self.device))
854855
return inverted
856+
857+
858+
class SobelGradients(Transform):
859+
"""Calculate Sobel horizontal and vertical gradients
860+
861+
Args:
862+
kernel_size: the size of the Sobel kernel. Defaults to 3.
863+
padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
864+
dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
865+
device: the device to create the kernel on. Defaults to `"cpu"`.
866+
867+
"""
868+
869+
backend = [TransformBackends.TORCH]
870+
871+
def __init__(
872+
self,
873+
kernel_size: int = 3,
874+
padding: Union[int, str] = "same",
875+
dtype: torch.dtype = torch.float32,
876+
device: Union[torch.device, int, str] = "cpu",
877+
) -> None:
878+
super().__init__()
879+
self.kernel: torch.Tensor = self._get_kernel(kernel_size, dtype, device)
880+
self.padding = padding
881+
882+
def _get_kernel(self, size, dtype, device) -> torch.Tensor:
883+
if size % 2 == 0:
884+
raise ValueError(f"Sobel kernel size should be an odd number. {size} was given.")
885+
if not dtype.is_floating_point:
886+
raise ValueError(f"`dtype` for Sobel kernel should be floating point. {dtype} was given.")
887+
888+
numerator: torch.Tensor = torch.arange(
889+
-size // 2 + 1, size // 2 + 1, dtype=dtype, device=device, requires_grad=False
890+
).expand(size, size)
891+
denominator = numerator * numerator
892+
denominator = denominator + denominator.T
893+
denominator[:, size // 2] = 1.0 # to avoid division by zero
894+
kernel = numerator / denominator
895+
return kernel
896+
897+
def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
898+
image_tensor = convert_to_tensor(image, track_meta=get_track_meta())
899+
kernel_v = self.kernel.to(image_tensor.device)
900+
kernel_h = kernel_v.T
901+
grad_v = apply_filter(image_tensor, kernel_v, padding=self.padding)
902+
grad_h = apply_filter(image_tensor, kernel_h, padding=self.padding)
903+
grad = torch.cat([grad_h, grad_v])
904+
905+
return grad

monai/transforms/post/dictionary.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
MeanEnsemble,
3737
ProbNMS,
3838
RemoveSmallObjects,
39+
SobelGradients,
3940
VoteEnsemble,
4041
)
4142
from monai.transforms.transform import MapTransform
@@ -795,6 +796,44 @@ def get_saver(self):
795796
return self.saver
796797

797798

799+
class SobelGradientsd(MapTransform):
800+
"""Calculate Sobel horizontal and vertical gradients.
801+
802+
Args:
803+
keys: keys of the corresponding items to model output.
804+
kernel_size: the size of the Sobel kernel. Defaults to 3.
805+
padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
806+
dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
807+
device: the device to create the kernel on. Defaults to `"cpu"`.
808+
new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of
809+
key intact. By default not prefix is set and the corresponding array to the key will be replaced.
810+
allow_missing_keys: don't raise exception if key is missing.
811+
812+
"""
813+
814+
def __init__(
815+
self,
816+
keys: KeysCollection,
817+
kernel_size: int = 3,
818+
padding: Union[int, str] = "same",
819+
dtype: torch.dtype = torch.float32,
820+
device: Union[torch.device, int, str] = "cpu",
821+
new_key_prefix: Optional[str] = None,
822+
allow_missing_keys: bool = False,
823+
) -> None:
824+
super().__init__(keys, allow_missing_keys)
825+
self.transform = SobelGradients(kernel_size=kernel_size, padding=padding, dtype=dtype, device=device)
826+
self.new_key_prefix = new_key_prefix
827+
828+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
829+
d = dict(data)
830+
for key in self.key_iterator(d):
831+
new_key = key if self.new_key_prefix is None else self.new_key_prefix + key
832+
d[new_key] = self.transform(d[key])
833+
834+
return d
835+
836+
798837
ActivationsD = ActivationsDict = Activationsd
799838
AsDiscreteD = AsDiscreteDict = AsDiscreted
800839
FillHolesD = FillHolesDict = FillHolesd
@@ -808,3 +847,4 @@ def get_saver(self):
808847
SaveClassificationD = SaveClassificationDict = SaveClassificationd
809848
VoteEnsembleD = VoteEnsembleDict = VoteEnsembled
810849
EnsembleD = EnsembleDict = Ensembled
850+
SobelGradientsD = SobelGradientsDict = SobelGradientsd

tests/test_sobel_gradient.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
from parameterized import parameterized
16+
17+
from monai.transforms import SobelGradients
18+
from tests.utils import assert_allclose
19+
20+
IMAGE = torch.zeros(1, 1, 16, 16, dtype=torch.float32)
21+
IMAGE[0, 0, 8, :] = 1
22+
OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
23+
OUTPUT_3x3[0, 7, :] = 2.0
24+
OUTPUT_3x3[0, 9, :] = -2.0
25+
OUTPUT_3x3[0, 7, 0] = OUTPUT_3x3[0, 7, -1] = 1.5
26+
OUTPUT_3x3[0, 9, 0] = OUTPUT_3x3[0, 9, -1] = -1.5
27+
OUTPUT_3x3[1, 7, 0] = OUTPUT_3x3[1, 9, 0] = 0.5
28+
OUTPUT_3x3[1, 8, 0] = 1.0
29+
OUTPUT_3x3[1, 8, -1] = -1.0
30+
OUTPUT_3x3[1, 7, -1] = OUTPUT_3x3[1, 9, -1] = -0.5
31+
OUTPUT_3x3 = OUTPUT_3x3.unsqueeze(1)
32+
33+
TEST_CASE_0 = [IMAGE, {"kernel_size": 3, "dtype": torch.float32}, OUTPUT_3x3]
34+
TEST_CASE_1 = [IMAGE, {"kernel_size": 3, "dtype": torch.float64}, OUTPUT_3x3]
35+
36+
TEST_CASE_KERNEL_0 = [
37+
{"kernel_size": 3, "dtype": torch.float64},
38+
torch.tensor([[-0.5, 0.0, 0.5], [-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5]], dtype=torch.float64),
39+
]
40+
TEST_CASE_KERNEL_1 = [
41+
{"kernel_size": 5, "dtype": torch.float64},
42+
torch.tensor(
43+
[
44+
[-0.25, -0.2, 0.0, 0.2, 0.25],
45+
[-0.4, -0.5, 0.0, 0.5, 0.4],
46+
[-0.5, -1.0, 0.0, 1.0, 0.5],
47+
[-0.4, -0.5, 0.0, 0.5, 0.4],
48+
[-0.25, -0.2, 0.0, 0.2, 0.25],
49+
],
50+
dtype=torch.float64,
51+
),
52+
]
53+
TEST_CASE_KERNEL_2 = [
54+
{"kernel_size": 7, "dtype": torch.float64},
55+
torch.tensor(
56+
[
57+
[-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
58+
[-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
59+
[-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
60+
[-3.0 / 9.0, -2.0 / 4.0, -1.0 / 1.0, 0.0, 1.0 / 1.0, 2.0 / 4.0, 3.0 / 9.0],
61+
[-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
62+
[-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
63+
[-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
64+
],
65+
dtype=torch.float64,
66+
),
67+
]
68+
TEST_CASE_ERROR_0 = [{"kernel_size": 2, "dtype": torch.float32}]
69+
70+
71+
class SobelGradientTests(unittest.TestCase):
72+
backend = None
73+
74+
@parameterized.expand([TEST_CASE_0])
75+
def test_sobel_gradients(self, image, arguments, expected_grad):
76+
sobel = SobelGradients(**arguments)
77+
grad = sobel(image)
78+
assert_allclose(grad, expected_grad)
79+
80+
@parameterized.expand([TEST_CASE_KERNEL_0, TEST_CASE_KERNEL_1, TEST_CASE_KERNEL_2])
81+
def test_sobel_kernels(self, arguments, expected_kernel):
82+
sobel = SobelGradients(**arguments)
83+
self.assertTrue(sobel.kernel.dtype == expected_kernel.dtype)
84+
assert_allclose(sobel.kernel, expected_kernel)
85+
86+
@parameterized.expand([TEST_CASE_ERROR_0])
87+
def test_sobel_gradients_error(self, arguments):
88+
with self.assertRaises(ValueError):
89+
SobelGradients(**arguments)
90+
91+
92+
if __name__ == "__main__":
93+
unittest.main()

tests/test_sobel_gradientd.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
from parameterized import parameterized
16+
17+
from monai.transforms import SobelGradientsd
18+
from tests.utils import assert_allclose
19+
20+
IMAGE = torch.zeros(1, 1, 16, 16, dtype=torch.float32)
21+
IMAGE[0, 0, 8, :] = 1
22+
OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
23+
OUTPUT_3x3[0, 7, :] = 2.0
24+
OUTPUT_3x3[0, 9, :] = -2.0
25+
OUTPUT_3x3[0, 7, 0] = OUTPUT_3x3[0, 7, -1] = 1.5
26+
OUTPUT_3x3[0, 9, 0] = OUTPUT_3x3[0, 9, -1] = -1.5
27+
OUTPUT_3x3[1, 7, 0] = OUTPUT_3x3[1, 9, 0] = 0.5
28+
OUTPUT_3x3[1, 8, 0] = 1.0
29+
OUTPUT_3x3[1, 8, -1] = -1.0
30+
OUTPUT_3x3[1, 7, -1] = OUTPUT_3x3[1, 9, -1] = -0.5
31+
OUTPUT_3x3 = OUTPUT_3x3.unsqueeze(1)
32+
33+
TEST_CASE_0 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 3, "dtype": torch.float32}, {"image": OUTPUT_3x3}]
34+
TEST_CASE_1 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 3, "dtype": torch.float64}, {"image": OUTPUT_3x3}]
35+
TEST_CASE_2 = [
36+
{"image": IMAGE},
37+
{"keys": "image", "kernel_size": 3, "dtype": torch.float32, "new_key_prefix": "sobel_"},
38+
{"sobel_image": OUTPUT_3x3},
39+
]
40+
41+
TEST_CASE_KERNEL_0 = [
42+
{"keys": "image", "kernel_size": 3, "dtype": torch.float64},
43+
torch.tensor([[-0.5, 0.0, 0.5], [-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5]], dtype=torch.float64),
44+
]
45+
TEST_CASE_KERNEL_1 = [
46+
{"keys": "image", "kernel_size": 5, "dtype": torch.float64},
47+
torch.tensor(
48+
[
49+
[-0.25, -0.2, 0.0, 0.2, 0.25],
50+
[-0.4, -0.5, 0.0, 0.5, 0.4],
51+
[-0.5, -1.0, 0.0, 1.0, 0.5],
52+
[-0.4, -0.5, 0.0, 0.5, 0.4],
53+
[-0.25, -0.2, 0.0, 0.2, 0.25],
54+
],
55+
dtype=torch.float64,
56+
),
57+
]
58+
TEST_CASE_KERNEL_2 = [
59+
{"keys": "image", "kernel_size": 7, "dtype": torch.float64},
60+
torch.tensor(
61+
[
62+
[-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
63+
[-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
64+
[-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
65+
[-3.0 / 9.0, -2.0 / 4.0, -1.0 / 1.0, 0.0, 1.0 / 1.0, 2.0 / 4.0, 3.0 / 9.0],
66+
[-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
67+
[-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
68+
[-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
69+
],
70+
dtype=torch.float64,
71+
),
72+
]
73+
TEST_CASE_ERROR_0 = [{"keys": "image", "kernel_size": 2, "dtype": torch.float32}]
74+
75+
76+
class SobelGradientTests(unittest.TestCase):
77+
backend = None
78+
79+
@parameterized.expand([TEST_CASE_0])
80+
def test_sobel_gradients(self, image_dict, arguments, expected_grad):
81+
sobel = SobelGradientsd(**arguments)
82+
grad = sobel(image_dict)
83+
key = "image" if "new_key_prefix" not in arguments else arguments["new_key_prefix"] + arguments["keys"]
84+
assert_allclose(grad[key], expected_grad[key])
85+
86+
@parameterized.expand([TEST_CASE_KERNEL_0, TEST_CASE_KERNEL_1, TEST_CASE_KERNEL_2])
87+
def test_sobel_kernels(self, arguments, expected_kernel):
88+
sobel = SobelGradientsd(**arguments)
89+
self.assertTrue(sobel.transform.kernel.dtype == expected_kernel.dtype)
90+
assert_allclose(sobel.transform.kernel, expected_kernel)
91+
92+
@parameterized.expand([TEST_CASE_ERROR_0])
93+
def test_sobel_gradients_error(self, arguments):
94+
with self.assertRaises(ValueError):
95+
SobelGradientsd(**arguments)
96+
97+
98+
if __name__ == "__main__":
99+
unittest.main()

0 commit comments

Comments
 (0)