Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use double precision by default in apply_transforms computations #585

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions ants/registration/apply_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def apply_transforms(fixed, moving, transformlist,
interpolator='linear', imagetype=0,
whichtoinvert=None, compose=None,
defaultvalue=0, verbose=False, **kwargs):
defaultvalue=0, singleprecision=False, verbose=False, **kwargs):
"""
Apply a transform list to map an image from one domain to another.
In image registration, one computes mappings between (usually) pairs
Expand Down Expand Up @@ -67,6 +67,10 @@ def apply_transforms(fixed, moving, transformlist,
defaultvalue : scalar
Default voxel value for mappings outside the image domain.

singleprecision : boolean
if True, use float32 for computations. This is useful for reducing memory
usage for large datasets, at the cost of precision.

verbose : boolean
print command and run verbose application of transform.

Expand Down Expand Up @@ -102,16 +106,18 @@ def apply_transforms(fixed, moving, transformlist,

args = [fixed, moving, transformlist, interpolator]

output_pixel_type = 'float' if singleprecision else 'double'

if not isinstance(fixed, str):
if isinstance(fixed, iio.ANTsImage) and isinstance(moving, iio.ANTsImage):
for tl_path in transformlist:
if not os.path.exists(tl_path):
raise Exception('Transform %s does not exist' % tl_path)

inpixeltype = fixed.pixeltype
fixed = fixed.clone('float')
moving = moving.clone('float')
warpedmovout = moving.clone()
fixed = fixed.clone(output_pixel_type)
moving = moving.clone(output_pixel_type)
warpedmovout = moving.clone(output_pixel_type)
f = fixed
m = moving
if (moving.dimension == 4) and (fixed.dimension == 3) and (imagetype == 0):
Expand Down Expand Up @@ -165,7 +171,7 @@ def apply_transforms(fixed, moving, transformlist,
if verbose:
print(myargs)

processed_args = myargs + ['-z', str(1), '-v', str(myverb), '--float', str(1), '-e', str(imagetype), '-f', str(defaultvalue)]
processed_args = myargs + ['-z', str(1), '-v', str(myverb), '--float', str(int(singleprecision)), '-e', str(imagetype), '-f', str(defaultvalue)]
libfn = utils.get_lib_fn('antsApplyTransforms')
libfn(processed_args)

Expand All @@ -180,7 +186,7 @@ def apply_transforms(fixed, moving, transformlist,
else:
return 1
else:
args = args + ['-z', 1, '--float', 1, '-e', imagetype, '-f', defaultvalue]
args = args + ['-z', str(1), '--float', str(int(singleprecision)), '-e', imagetype, '-f', defaultvalue]
processed_args = utils._int_antsProcessArguments(args)
libfn = utils.get_lib_fn('antsApplyTransforms')
libfn(processed_args)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_core_ants_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_apply(self):
img = ants.image_read(ants.get_ants_data("r16")).clone('float')
tx = ants.new_ants_transform(dimension=2)
tx.set_parameters((0.9,0,0,1.1,10,11))
img2 = tx.apply(data=img, reference=img, data_type='image')
img2 = tx.apply(data=img, reference=img, data_type='image')

def test_apply_to_point(self):
tx = ants.new_ants_transform()
Expand All @@ -99,14 +99,14 @@ def test_apply_to_vector(self):
tx = ants.new_ants_transform()
params = tx.parameters
tx.set_parameters(params*2)
pt2 = tx.apply_to_vector((1,2,3)) # should be (2,4,6)
pt2 = tx.apply_to_vector((1,2,3)) # should be (2,4,6)

def test_apply_to_image(self):
for ptype in self.pixeltypes:
img = ants.image_read(ants.get_ants_data("r16")).clone(ptype)
tx = ants.new_ants_transform(dimension=2)
tx.set_parameters((0.9,0,0,1.1,10,11))
img2 = tx.apply_to_image(img, img)
img2 = tx.apply_to_image(img, img)


class TestModule_ants_transform(unittest.TestCase):
Expand Down Expand Up @@ -154,7 +154,8 @@ def test_apply_ants_transform(self):
img = ants.image_read(ants.get_ants_data("r16")).clone('float')
tx = ants.new_ants_transform(dimension=2)
tx.set_parameters((0.9,0,0,1.1,10,11))
img2 = ants.apply_ants_transform(tx, data=img, reference=img, data_type='image')
img2 = ants.apply_ants_transform(tx, data=img, reference=img, data_type='image')


def test_apply_ants_transform_to_point(self):
tx = ants.new_ants_transform()
Expand Down
12 changes: 11 additions & 1 deletion tests/test_registation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,21 @@ def test_example(self):
fixed = ants.image_read(ants.get_ants_data("r16"))
moving = ants.image_read(ants.get_ants_data("r64"))
fixed = ants.resample_image(fixed, (64, 64), 1, 0)
moving = ants.resample_image(moving, (64, 64), 1, 0)
moving = ants.resample_image(moving, (128, 128), 1, 0)
mytx = ants.registration(fixed=fixed, moving=moving, type_of_transform="SyN")
mywarpedimage = ants.apply_transforms(
fixed=fixed, moving=moving, transformlist=mytx["fwdtransforms"]
)
self.assertEqual(mywarpedimage.pixeltype, moving.pixeltype)
self.assertTrue(ants.ants_image.image_physical_space_consistency(fixed, mywarpedimage,
0.0001, datatype = False))

# Call with float precision for transforms, but should still return input type
mywarpedimage2 = ants.apply_transforms(
fixed=fixed, moving=moving, transformlist=mytx["fwdtransforms"], singleprecision=True
)
self.assertEqual(mywarpedimage2.pixeltype, moving.pixeltype)
self.assertAlmostEqual(mywarpedimage.sum(), mywarpedimage2.sum(), places=3)

# bad interpolator
with self.assertRaises(Exception):
Expand Down