Skip to content

Commit

Permalink
support slicing vector images
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 23, 2024
1 parent 00c264a commit 2569829
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
9 changes: 8 additions & 1 deletion ants/ops/slice_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
__all__ = ['slice_image']

import math

import numpy as np
import ants
from ants.decorators import image_method
from ants.internal import get_lib_fn
Expand Down Expand Up @@ -32,6 +32,13 @@ def slice_image(image, axis, idx, collapse_strategy=0):
>>> mni = ants.image_read(ants.get_data('mni'))
>>> mni2 = ants.slice_image(mni, axis=1, idx=100)
"""
if image.has_components:
ilist = ants.split_channels(image)
if image.dimension == 2:
return np.stack(tuple([i.slice_image(axis, idx, collapse_strategy) for i in ilist]), axis=-1)
else:
return ants.merge_channels([i.slice_image(axis, idx, collapse_strategy) for i in ilist])

if axis == -1:
axis = image.dimension - 1

Expand Down
53 changes: 53 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,29 @@ def test_slice_image_2d(self):
img3 = ants.slice_image(img, 1, 100)
self.assertTrue(np.allclose(img2, img3))

def test_slice_image_2d_vector(self):
img0 = ants.image_read(ants.get_ants_data('r16'))
img = ants.merge_channels([img0,img0,img0])

img2 = ants.slice_image(img, 0, 100)
self.assertTrue(isinstance(img2, np.ndarray))

img2 = ants.slice_image(img, 1, 100)
self.assertTrue(isinstance(img2, np.ndarray))

with self.assertRaises(Exception):
img2 = ants.slice_image(img, 2, 100)

with self.assertRaises(Exception):
img2 = ants.slice_image(img, -3, 100)

with self.assertRaises(Exception):
img2 = ants.slice_image(img, 4, 100)

img2 = ants.slice_image(img, -1, 100)
img3 = ants.slice_image(img, 1, 100)
self.assertTrue(np.allclose(img2, img3))

def test_slice_image_3d(self):
"""
Test that resampling an image doesnt cause the resampled
Expand Down Expand Up @@ -76,6 +98,37 @@ def test_slice_image_3d(self):
img3 = ants.slice_image(img, 2, 100)
self.assertTrue(ants.allclose(img2, img3))

def test_slice_image_3d_vector(self):
"""
Test that resampling an image doesnt cause the resampled
image to have NaNs - previously caused by resampling an
image of type DOUBLE
"""
img0 = ants.image_read(ants.get_ants_data('mni'))
img = ants.merge_channels([img0,img0,img0])

img2 = ants.slice_image(img, 0, 100)
self.assertEqual(img2.dimension, 2)

img2 = ants.slice_image(img, 1, 100)
self.assertEqual(img2.dimension, 2)

img2 = ants.slice_image(img, 2, 100)
self.assertEqual(img2.dimension, 2)

img2 = ants.slice_image(img.clone('unsigned int'), 2, 100)
self.assertEqual(img2.dimension, 2)

with self.assertRaises(Exception):
img2 = ants.slice_image(img, 3, 100)

with self.assertRaises(Exception):
img2 = ants.slice_image(img, 2, 100, collapse_strategy=23)

img2 = ants.slice_image(img, -1, 100)
img3 = ants.slice_image(img, 2, 100)
self.assertTrue(ants.allclose(img2, img3))


if __name__ == '__main__':
run_tests()

0 comments on commit 2569829

Please sign in to comment.