Skip to content

Commit

Permalink
Merge pull request #657 from ANTsX/support-channels-first-numpy
Browse files Browse the repository at this point in the history
ENH: support channels first for vector images
  • Loading branch information
Nicholas Cullen, PhD authored May 24, 2024
2 parents de706a8 + c743f9c commit 6d99b60
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
4 changes: 3 additions & 1 deletion ants/core/ants_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, pointer):
"""
self.pointer = pointer
self.channels_first = False
self._array = None

@property
Expand Down Expand Up @@ -254,7 +255,8 @@ def numpy(self, single_components=False):
"""
array = np.array(self.view(single_components=single_components), copy=True, dtype=self.dtype)
if self.has_components or (single_components == True):
array = np.rollaxis(array, 0, self.dimension+1)
if not self.channels_first:
array = np.rollaxis(array, 0, self.dimension+1)
return array

def astype(self, dtype):
Expand Down
12 changes: 8 additions & 4 deletions ants/utils/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ants.decorators import image_method


def merge_channels(image_list):
def merge_channels(image_list, channels_first=False):
"""
Merge channels of multiple scalar ANTsImage types into one
multi-channel ANTsImage
Expand All @@ -31,9 +31,11 @@ def merge_channels(image_list):
Example
-------
>>> import ants
>>> image = ants.image_read(ants.get_ants_data('r16'), 'float')
>>> image2 = ants.image_read(ants.get_ants_data('r16'), 'float')
>>> image = ants.image_read(ants.get_ants_data('r16'))
>>> image2 = ants.image_read(ants.get_ants_data('r16'))
>>> image3 = ants.merge_channels([image,image2])
>>> image3 = ants.merge_channels([image,image2], channels_first=True)
>>> image3.numpy()
>>> image3.components == 2
"""
inpixeltype = image_list[0].pixeltype
Expand All @@ -49,7 +51,9 @@ def merge_channels(image_list):
libfn = get_lib_fn('mergeChannels')
image_ptr = libfn([image.pointer for image in image_list])

return ants.from_pointer(image_ptr)
image = ants.from_pointer(image_ptr)
image.channels_first = channels_first
return image

@image_method
def split_channels(image):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,16 @@ def test_hausdorff_distance(self):
s16 = ants.kmeans_segmentation( r16, 3 )['segmentation']
s64 = ants.kmeans_segmentation( r64, 3 )['segmentation']
stats = ants.hausdorff_distance(s16, s64)

def test_channels_first(self):
import ants
image = ants.image_read(ants.get_ants_data('r16'))
image2 = ants.image_read(ants.get_ants_data('r16'))
img3 = ants.merge_channels([image,image2])
img4 = ants.merge_channels([image,image2], channels_first=True)

self.assertTrue(np.allclose(img3.numpy()[:,:,0], img4.numpy()[0,:,:]))
self.assertTrue(np.allclose(img3.numpy()[:,:,1], img4.numpy()[1,:,:]))


if __name__ == "__main__":
Expand Down

0 comments on commit 6d99b60

Please sign in to comment.