diff --git a/niworkflows/interfaces/nibabel.py b/niworkflows/interfaces/nibabel.py index 9364636279b..9c660d1d4c5 100644 --- a/niworkflows/interfaces/nibabel.py +++ b/niworkflows/interfaces/nibabel.py @@ -1,27 +1,34 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Nibabel-based interfaces.""" +from pathlib import Path import numpy as np import nibabel as nb from nipype import logging from nipype.utils.filemanip import fname_presuffix from nipype.interfaces.base import ( - traits, TraitedSpec, BaseInterfaceInputSpec, File, - SimpleInterface + traits, + TraitedSpec, + BaseInterfaceInputSpec, + File, + SimpleInterface, + OutputMultiObject, + InputMultiObject, ) -IFLOGGER = logging.getLogger('nipype.interface') +IFLOGGER = logging.getLogger("nipype.interface") class _ApplyMaskInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='an image') - in_mask = File(exists=True, mandatory=True, desc='a mask') - threshold = traits.Float(0.5, usedefault=True, - desc='a threshold to the mask, if it is nonbinary') + in_file = File(exists=True, mandatory=True, desc="an image") + in_mask = File(exists=True, mandatory=True, desc="a mask") + threshold = traits.Float( + 0.5, usedefault=True, desc="a threshold to the mask, if it is nonbinary" + ) class _ApplyMaskOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='masked file') + out_file = File(exists=True, desc="masked file") class ApplyMask(SimpleInterface): @@ -35,8 +42,9 @@ def _run_interface(self, runtime): msknii = nb.load(self.inputs.in_mask) msk = msknii.get_fdata() > self.inputs.threshold - self._results['out_file'] = fname_presuffix( - self.inputs.in_file, suffix='_masked', newpath=runtime.cwd) + self._results["out_file"] = fname_presuffix( + self.inputs.in_file, suffix="_masked", newpath=runtime.cwd + ) if img.dataobj.shape[:3] != msk.shape: raise ValueError("Image and mask sizes do not match.") @@ -48,19 +56,18 @@ def _run_interface(self, runtime): msk = msk[..., np.newaxis] masked = img.__class__(img.dataobj * msk, None, img.header) - masked.to_filename(self._results['out_file']) + masked.to_filename(self._results["out_file"]) return runtime class _BinarizeInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='input image') - thresh_low = traits.Float(mandatory=True, - desc='non-inclusive lower threshold') + in_file = File(exists=True, mandatory=True, desc="input image") + thresh_low = traits.Float(mandatory=True, desc="non-inclusive lower threshold") class _BinarizeOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='masked file') - out_mask = File(exists=True, desc='output mask') + out_file = File(exists=True, desc="masked file") + out_mask = File(exists=True, desc="output mask") class Binarize(SimpleInterface): @@ -72,20 +79,98 @@ class Binarize(SimpleInterface): def _run_interface(self, runtime): img = nb.load(self.inputs.in_file) - self._results['out_file'] = fname_presuffix( - self.inputs.in_file, suffix='_masked', newpath=runtime.cwd) - self._results['out_mask'] = fname_presuffix( - self.inputs.in_file, suffix='_mask', newpath=runtime.cwd) + self._results["out_file"] = fname_presuffix( + self.inputs.in_file, suffix="_masked", newpath=runtime.cwd + ) + self._results["out_mask"] = fname_presuffix( + self.inputs.in_file, suffix="_mask", newpath=runtime.cwd + ) data = img.get_fdata() mask = data > self.inputs.thresh_low data[~mask] = 0.0 masked = img.__class__(data, img.affine, img.header) - masked.to_filename(self._results['out_file']) + masked.to_filename(self._results["out_file"]) - img.header.set_data_dtype('uint8') - maskimg = img.__class__(mask.astype('uint8'), img.affine, - img.header) - maskimg.to_filename(self._results['out_mask']) + img.header.set_data_dtype("uint8") + maskimg = img.__class__(mask.astype("uint8"), img.affine, img.header) + maskimg.to_filename(self._results["out_mask"]) return runtime + + +class _SplitSeriesInputSpec(BaseInterfaceInputSpec): + in_file = File(exists=True, mandatory=True, desc="input 4d image") + + +class _SplitSeriesOutputSpec(TraitedSpec): + out_files = OutputMultiObject(File(exists=True), desc="output list of 3d images") + + +class SplitSeries(SimpleInterface): + """Split a 4D dataset along the last dimension into a series of 3D volumes.""" + + input_spec = _SplitSeriesInputSpec + output_spec = _SplitSeriesOutputSpec + + def _run_interface(self, runtime): + in_file = self.inputs.in_file + img = nb.load(in_file) + extra_dims = tuple(dim for dim in img.shape[3:] if dim > 1) or (1,) + if len(extra_dims) != 1: + raise ValueError(f"Invalid shape {'x'.join(str(s) for s in img.shape)}") + img = img.__class__(img.dataobj.reshape(img.shape[:3] + extra_dims), + img.affine, img.header) + + self._results["out_files"] = [] + for i, img_3d in enumerate(nb.four_to_three(img)): + out_file = str( + Path(fname_presuffix(in_file, suffix=f"_idx-{i:03}")).absolute() + ) + img_3d.to_filename(out_file) + self._results["out_files"].append(out_file) + + return runtime + + +class _MergeSeriesInputSpec(BaseInterfaceInputSpec): + in_files = InputMultiObject( + File(exists=True, mandatory=True, desc="input list of 3d images") + ) + allow_4D = traits.Bool( + True, usedefault=True, desc="whether 4D images are allowed to be concatenated" + ) + + +class _MergeSeriesOutputSpec(TraitedSpec): + out_file = File(exists=True, desc="output 4d image") + + +class MergeSeries(SimpleInterface): + """Merge a series of 3D volumes along the last dimension into a single 4D image.""" + + input_spec = _MergeSeriesInputSpec + output_spec = _MergeSeriesOutputSpec + + def _run_interface(self, runtime): + nii_list = [] + for f in self.inputs.in_files: + filenii = nb.squeeze_image(nb.load(f)) + ndim = filenii.dataobj.ndim + if ndim == 3: + nii_list.append(filenii) + continue + elif self.inputs.allow_4D and ndim == 4: + nii_list += nb.four_to_three(filenii) + continue + else: + raise ValueError( + "Input image has an incorrect number of dimensions" f" ({ndim})." + ) + + img_4d = nb.concat_images(nii_list) + out_file = fname_presuffix(self.inputs.in_files[0], suffix="_merged") + img_4d.to_filename(out_file) + + self._results["out_file"] = out_file + return runtime diff --git a/niworkflows/interfaces/tests/test_nibabel.py b/niworkflows/interfaces/tests/test_nibabel.py index abdab406fcf..3ab8c4ffe47 100644 --- a/niworkflows/interfaces/tests/test_nibabel.py +++ b/niworkflows/interfaces/tests/test_nibabel.py @@ -4,7 +4,7 @@ import nibabel as nb import pytest -from ..nibabel import Binarize, ApplyMask +from ..nibabel import Binarize, ApplyMask, SplitSeries, MergeSeries def test_Binarize(tmp_path): @@ -14,10 +14,10 @@ def test_Binarize(tmp_path): mask = np.zeros((20, 20, 20), dtype=bool) mask[5:15, 5:15, 5:15] = bool - data = np.zeros_like(mask, dtype='float32') + data = np.zeros_like(mask, dtype="float32") data[mask] = np.random.gamma(2, size=mask.sum()) - in_file = tmp_path / 'input.nii.gz' + in_file = tmp_path / "input.nii.gz" nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file)) binif = Binarize(thresh_low=0.0, in_file=str(in_file)).run() @@ -36,28 +36,32 @@ def test_ApplyMask(tmp_path): mask[8:11, 8:11, 8:11] = 1.0 # Test the 3D - in_file = tmp_path / 'input3D.nii.gz' + in_file = tmp_path / "input3D.nii.gz" nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file)) - in_mask = tmp_path / 'mask.nii.gz' + in_mask = tmp_path / "mask.nii.gz" nb.Nifti1Image(mask, np.eye(4), None).to_filename(str(in_mask)) masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run() - assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3 + assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3 masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.6).run() - assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3 + assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3 data4d = np.stack((data, 2 * data, 3 * data), axis=-1) # Test the 4D case - in_file4d = tmp_path / 'input4D.nii.gz' + in_file4d = tmp_path / "input4D.nii.gz" nb.Nifti1Image(data4d, np.eye(4), None).to_filename(str(in_file4d)) - masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run() - assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3 * 6 + masked1 = ApplyMask( + in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4 + ).run() + assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3 * 6 - masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6).run() - assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3 * 6 + masked1 = ApplyMask( + in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6 + ).run() + assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3 * 6 # Test errors nb.Nifti1Image(mask, 2 * np.eye(4), None).to_filename(str(in_mask)) @@ -69,3 +73,58 @@ def test_ApplyMask(tmp_path): ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run() with pytest.raises(ValueError): ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run() + + +@pytest.mark.parametrize("shape,exp_n", [ + ((20, 20, 20, 15), 15), + ((20, 20, 20), 1), + ((20, 20, 20, 1), 1), + ((20, 20, 20, 1, 3), 3), + ((20, 20, 20, 3, 1), 3), + ((20, 20, 20, 1, 3, 3), -1), + ((20, 1, 20, 15), 15), + ((20, 1, 20), 1), + ((20, 1, 20, 1), 1), + ((20, 1, 20, 1, 3), 3), + ((20, 1, 20, 3, 1), 3), + ((20, 1, 20, 1, 3, 3), -1), +]) +def test_SplitSeries(tmp_path, shape, exp_n): + """Test 4-to-3 NIfTI split interface.""" + os.chdir(tmp_path) + + in_file = str(tmp_path / "input.nii.gz") + nb.Nifti1Image(np.ones(shape, dtype=float), np.eye(4), None).to_filename(in_file) + + _interface = SplitSeries(in_file=in_file) + if exp_n > 0: + split = _interface.run() + n = int(isinstance(split.outputs.out_files, str)) or len(split.outputs.out_files) + assert n == exp_n + else: + with pytest.raises(ValueError): + _interface.run() + + +def test_MergeSeries(tmp_path): + """Test 3-to-4 NIfTI concatenation interface.""" + os.chdir(str(tmp_path)) + + in_file = tmp_path / "input3D.nii.gz" + nb.Nifti1Image(np.ones((20, 20, 20), dtype=float), np.eye(4), None).to_filename( + str(in_file) + ) + + merge = MergeSeries(in_files=[str(in_file)] * 5).run() + assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5) + + in_4D = tmp_path / "input4D.nii.gz" + nb.Nifti1Image(np.ones((20, 20, 20, 4), dtype=float), np.eye(4), None).to_filename( + str(in_4D) + ) + + merge = MergeSeries(in_files=[str(in_file)] + [str(in_4D)]).run() + assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5) + + with pytest.raises(ValueError): + MergeSeries(in_files=[str(in_file)] + [str(in_4D)], allow_4D=False).run()