diff --git a/nitransforms/base.py b/nitransforms/base.py index a264ed94..2932daef 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -270,7 +270,10 @@ def apply( if isinstance(spatialimage, (str, Path)): spatialimage = _nbload(str(spatialimage)) - data = np.asanyarray(spatialimage.dataobj) + data = np.asanyarray( + spatialimage.dataobj, + dtype=spatialimage.get_data_dtype() + ) output_dtype = output_dtype or data.dtype targets = ImageGrid(spatialimage).index( # data should be an image _as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim) @@ -288,9 +291,11 @@ def apply( if isinstance(_ref, ImageGrid): # If reference is grid, reshape moved = spatialimage.__class__( - resampled.reshape(_ref.shape), _ref.affine, spatialimage.header + resampled.reshape(_ref.shape).astype(output_dtype), + _ref.affine, + spatialimage.header ) - moved.header.set_data_dtype(output_dtype) + moved.set_data_dtype(output_dtype) return moved return resampled diff --git a/nitransforms/tests/test_base.py b/nitransforms/tests/test_base.py index 8506723f..4940ac4f 100644 --- a/nitransforms/tests/test_base.py +++ b/nitransforms/tests/test_base.py @@ -88,19 +88,26 @@ def _to_hdf5(klass, x5_root): monkeypatch.setattr(TransformBase, "_to_hdf5", _to_hdf5) fname = testdata_path / "someones_anatomy.nii.gz" + img = nb.load(fname) + imgdata = np.asanyarray(img.dataobj, dtype=img.get_data_dtype()) + # Test identity transform xfm = TransformBase() xfm.reference = fname assert xfm.ndim == 3 moved = xfm.apply(fname, order=0) - assert np.all(nb.load(str(fname)).get_fdata() == moved.get_fdata()) + assert np.all( + imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype()) + ) # Test identity transform - setting reference xfm = TransformBase() xfm.reference = fname assert xfm.ndim == 3 moved = xfm.apply(str(fname), reference=fname, order=0) - assert np.all(nb.load(str(fname)).get_fdata() == moved.get_fdata()) + assert np.all( + imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype()) + ) # Test applying to Gifti gii = nb.gifti.GiftiImage( diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 48528896..7f9c4d25 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -216,6 +216,7 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj) assert np.sqrt((diff ** 2).mean()) < RMSE_TOL + brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool) cmd = APPLY_LINEAR_CMD[sw_tool]( transform=os.path.abspath(xfm_fname), @@ -224,23 +225,25 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient resampled=os.path.abspath("resampled.nii.gz"), ) - brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool) - exit_code = check_call([cmd], shell=True) assert exit_code == 0 sw_moved = nb.load("resampled.nii.gz") sw_moved.set_data_dtype(img.get_data_dtype()) nt_moved = xfm.apply(img, order=0) - diff = (sw_moved.get_fdata() - nt_moved.get_fdata()) - diff[~brainmask] = 0.0 - diff[np.abs(diff) < 1e-3] = 0 + diff = ( + np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) + - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) + ) # A certain tolerance is necessary because of resampling at borders - assert np.sqrt((diff ** 2).mean()) < RMSE_TOL + assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL nt_moved = xfm.apply("img.nii.gz", order=0) - diff = sw_moved.get_fdata() - nt_moved.get_fdata() + diff = ( + np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) + - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) + ) # A certain tolerance is necessary because of resampling at borders assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py index 1977ba21..6dee540e 100644 --- a/nitransforms/tests/test_manip.py +++ b/nitransforms/tests/test_manip.py @@ -11,7 +11,7 @@ from ..manip import load as _load, TransformChain from ..linear import Affine from .test_nonlinear import ( - TESTS_BORDER_TOLERANCE, + RMSE_TOL, APPLY_NONLINEAR_CMD, ) @@ -38,7 +38,11 @@ def test_itk_h5(tmp_path, testdata_path): # Then apply the transform and cross-check with software cmd = APPLY_NONLINEAR_CMD["itk"]( - transform=xfm_fname, reference=ref_fname, moving=img_fname + transform=xfm_fname, + reference=ref_fname, + moving=img_fname, + output="resampled.nii.gz", + extra="", ) # skip test if command is not available on host @@ -54,7 +58,7 @@ def test_itk_h5(tmp_path, testdata_path): nt_moved.to_filename("nt_resampled.nii.gz") diff = sw_moved.get_fdata() - nt_moved.get_fdata() # A certain tolerance is necessary because of resampling at borders - assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE + assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL @pytest.mark.parametrize("ext0", ["lta", "tfm"]) diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index 1f5e8b5a..b0cacc3d 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -13,18 +13,18 @@ from ..io.itk import ITKDisplacementsField -TESTS_BORDER_TOLERANCE = 0.05 +RMSE_TOL = 0.05 APPLY_NONLINEAR_CMD = { "itk": """\ antsApplyTransforms -d 3 -r {reference} -i {moving} \ --o resampled.nii.gz -n NearestNeighbor -t {transform} --float\ +-o {output} -n NearestNeighbor -t {transform} {extra}\ """.format, "afni": """\ 3dNwarpApply -nwarp {transform} -source {moving} \ --master {reference} -interp NN -prefix resampled.nii.gz +-master {reference} -interp NN -prefix {output} {extra}\ """.format, 'fsl': """\ -applywarp -i {moving} -r {reference} -o resampled.nii.gz \ +applywarp -i {moving} -r {reference} -o {output} {extra}\ -w {transform} --interp=nn""".format, } @@ -56,13 +56,23 @@ def test_itk_disp_load_intent(): @pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"]) @pytest.mark.parametrize("sw_tool", ["itk", "afni"]) @pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)]) -def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis): +def test_displacements_field1( + tmp_path, + get_testdata, + get_testmask, + image_orientation, + sw_tool, + axis, +): """Check a translation-only field on one or more axes, different image orientations.""" if (image_orientation, sw_tool) == ("oblique", "afni") and axis in ((1, 2), (0, 1, 2)): pytest.skip("AFNI Deoblique unsupported.") os.chdir(str(tmp_path)) nii = get_testdata[image_orientation] + msk = get_testmask[image_orientation] nii.to_filename("reference.nii.gz") + msk.to_filename("mask.nii.gz") + fieldmap = np.zeros( (*nii.shape[:3], 1, 3) if sw_tool != "fsl" else (*nii.shape[:3], 3), dtype="float32", @@ -83,8 +93,10 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool # Then apply the transform and cross-check with software cmd = APPLY_NONLINEAR_CMD[sw_tool]( transform=os.path.abspath(xfm_fname), - reference=tmp_path / "reference.nii.gz", - moving=tmp_path / "reference.nii.gz", + reference=tmp_path / "mask.nii.gz", + moving=tmp_path / "mask.nii.gz", + output=tmp_path / "resampled_brainmask.nii.gz", + extra="--output-data-type uchar" if sw_tool == "itk" else "", ) # skip test if command is not available on host @@ -92,15 +104,39 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool if not shutil.which(exe): pytest.skip("Command {} not found on host".format(exe)) + # resample mask + exit_code = check_call([cmd], shell=True) + assert exit_code == 0 + sw_moved_mask = nb.load("resampled_brainmask.nii.gz") + nt_moved_mask = xfm.apply(msk, order=0) + nt_moved_mask.set_data_dtype(msk.get_data_dtype()) + diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj) + + assert np.sqrt((diff ** 2).mean()) < RMSE_TOL + brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool) + + # Then apply the transform and cross-check with software + cmd = APPLY_NONLINEAR_CMD[sw_tool]( + transform=os.path.abspath(xfm_fname), + reference=tmp_path / "reference.nii.gz", + moving=tmp_path / "reference.nii.gz", + output=tmp_path / "resampled.nii.gz", + extra="--output-data-type uchar" if sw_tool == "itk" else "" + ) + exit_code = check_call([cmd], shell=True) assert exit_code == 0 sw_moved = nb.load("resampled.nii.gz") nt_moved = xfm.apply(nii, order=0) nt_moved.to_filename("nt_resampled.nii.gz") - diff = sw_moved.get_fdata() - nt_moved.get_fdata() + sw_moved.set_data_dtype(nt_moved.get_data_dtype()) + diff = ( + np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) + - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) + ) # A certain tolerance is necessary because of resampling at borders - assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE + assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL @pytest.mark.parametrize("sw_tool", ["itk", "afni"]) @@ -116,7 +152,11 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool): # Then apply the transform and cross-check with software cmd = APPLY_NONLINEAR_CMD[sw_tool]( - transform=xfm_fname, reference=img_fname, moving=img_fname + transform=xfm_fname, + reference=img_fname, + moving=img_fname, + output="resampled.nii.gz", + extra="", ) # skip test if command is not available on host @@ -130,6 +170,10 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool): nt_moved = xfm.apply(img_fname, order=0) nt_moved.to_filename("nt_resampled.nii.gz") - diff = sw_moved.get_fdata() - nt_moved.get_fdata() + sw_moved.set_data_dtype(nt_moved.get_data_dtype()) + diff = ( + np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) + - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) + ) # A certain tolerance is necessary because of resampling at borders - assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE + assert np.sqrt((diff ** 2).mean()) < RMSE_TOL