Skip to content

Commit

Permalink
2231 Enhance tensor transforms (#2966)
Browse files Browse the repository at this point in the history
* [DLMED] enhance tensor transforms

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix tests

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix mypy

Signed-off-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
Nic-Ma authored Sep 17, 2021
1 parent dc3e263 commit ecbb03b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
6 changes: 4 additions & 2 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "z
self.mode = mode
self.img_t: torch.Tensor = torch.tensor(0.0)

def __call__(self, img: NdarrayOrTensor) -> torch.Tensor:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...].
Expand All @@ -969,7 +969,9 @@ def __call__(self, img: NdarrayOrTensor) -> torch.Tensor:
# add one to transform axis because a batch axis will be added at dimension 0
savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode)
# convert to Tensor and add Batch axis expected by HilbertTransform
out: torch.Tensor = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0)
smoothed = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0)
out, *_ = convert_to_dst_type(smoothed, dst=img)

return out


Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ class Zoom(Transform):
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
backend = [TransformBackends.TORCH]

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_savitzky_golay_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ class TestSavitzkyGolaySmooth(unittest.TestCase):
@parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH])
def test_value(self, arguments, image, expected_data, atol):
for p in TEST_NDARRAYS:
result = SavitzkyGolaySmooth(**arguments)(p(image))
result = SavitzkyGolaySmooth(**arguments)(p(image.astype(np.float32)))
torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol)


class TestSavitzkyGolaySmoothREP(unittest.TestCase):
@parameterized.expand([TEST_CASE_SINGLE_VALUE_REP])
def test_value(self, arguments, image, expected_data, atol):
for p in TEST_NDARRAYS:
result = SavitzkyGolaySmooth(**arguments)(p(image))
result = SavitzkyGolaySmooth(**arguments)(p(image.astype(np.float32)))
torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol)


Expand Down

0 comments on commit ecbb03b

Please sign in to comment.