diff --git a/niworkflows/interfaces/images.py b/niworkflows/interfaces/images.py index 90f608b3d10..2da065a2b2a 100644 --- a/niworkflows/interfaces/images.py +++ b/niworkflows/interfaces/images.py @@ -12,7 +12,7 @@ from nipype.utils.filemanip import fname_presuffix from nipype.interfaces.base import ( traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, - File, InputMultiPath, OutputMultiPath) + File, InputMultiPath, OutputMultiPath, isdefined) from nipype.interfaces import fsl LOGGER = logging.getLogger('nipype.interface') @@ -21,9 +21,11 @@ class _IntraModalMergeInputSpec(BaseInterfaceInputSpec): in_files = InputMultiPath(File(exists=True), mandatory=True, desc='input files') + in_mask = File(exists=True, desc='input mask for grand mean scaling') hmc = traits.Bool(True, usedefault=True) zero_based_avg = traits.Bool(True, usedefault=True) to_ras = traits.Bool(True, usedefault=True) + grand_mean_scaling = traits.Bool(False, usedefault=True) class _IntraModalMergeOutputSpec(TraitedSpec): @@ -34,6 +36,14 @@ class _IntraModalMergeOutputSpec(TraitedSpec): class IntraModalMerge(SimpleInterface): + """ + Calculate an average of the inputs. + + If the input is 3D, returns the original image. + Otherwise, splits the images and merges them after + head-motion correction with FSL ``mcflirt``. + """ + input_spec = _IntraModalMergeInputSpec output_spec = _IntraModalMergeOutputSpec @@ -42,50 +52,72 @@ def _run_interface(self, runtime): if not isinstance(in_files, list): in_files = [self.inputs.in_files] - # Generate output average name early - self._results['out_avg'] = fname_presuffix(self.inputs.in_files[0], - suffix='_avg', newpath=runtime.cwd) - if self.inputs.to_ras: in_files = [reorient(inf, newpath=runtime.cwd) for inf in in_files] - if len(in_files) == 1: - filenii = nb.load(in_files[0]) + run_hmc = self.inputs.hmc and len(in_files) > 1 - # magnitude files can have an extra dimension empty + nii_list = [] + # Remove one-sized extra dimensions + for i, f in enumerate(in_files): + filenii = nb.load(f) + filenii = nb.squeeze_image(filenii) if len(filenii.shape) == 5: - filenii = nb.squeeze_image(filenii) - if len(filenii.shape) == 5: - raise RuntimeError('Input image (%s) is 5D' % in_files[0]) - - in_files = [fname_presuffix(in_files[0], suffix='_squeezed', - newpath=runtime.cwd)] - filenii.to_filename(in_files[0]) - - if filenii.dataobj.ndim < 4: - self._results['out_file'] = in_files[0] - self._results['out_avg'] = in_files[0] - # TODO: generate identity out_mats and zero-filled out_movpar - return runtime - in_files = in_files[0] + raise RuntimeError('Input image (%s) is 5D.' % f) + if filenii.dataobj.ndim == 4: + nii_list += nb.four_to_three(filenii) + else: + nii_list.append(filenii) + + if len(nii_list) > 1: + filenii = nb.concat_images(nii_list) else: - magmrg = fsl.Merge(dimension='t', in_files=self.inputs.in_files) - in_files = magmrg.run().outputs.merged_file - mcflirt = fsl.MCFLIRT(cost='normcorr', save_mats=True, save_plots=True, - ref_vol=0, in_file=in_files) - mcres = mcflirt.run() - self._results['out_mats'] = mcres.outputs.mat_file - self._results['out_movpar'] = mcres.outputs.par_file - self._results['out_file'] = mcres.outputs.out_file - - hmcnii = nb.load(mcres.outputs.out_file) - hmcdat = hmcnii.get_fdata().mean(axis=3) + filenii = nii_list[0] + + merged_fname = fname_presuffix(self.inputs.in_files[0], + suffix='_merged', newpath=runtime.cwd) + filenii.to_filename(merged_fname) + self._results['out_file'] = merged_fname + self._results['out_avg'] = merged_fname + + if filenii.dataobj.ndim < 4: + # TODO: generate identity out_mats and zero-filled out_movpar + return runtime + + if run_hmc: + mcflirt = fsl.MCFLIRT(cost='normcorr', save_mats=True, save_plots=True, + ref_vol=0, in_file=merged_fname) + mcres = mcflirt.run() + filenii = nb.load(mcres.outputs.out_file) + self._results['out_file'] = mcres.outputs.out_file + self._results['out_mats'] = mcres.outputs.mat_file + self._results['out_movpar'] = mcres.outputs.par_file + + hmcdata = filenii.get_fdata(dtype='float32') + if self.inputs.grand_mean_scaling: + if not isdefined(self.inputs.mask): + mean = np.median(hmcdata, axis=-1) + thres = np.percentile(mean, 25) + mask = mean > thres + else: + mask = nb.load(self.inputs.in_mask).get_fdata(dtype='float32') > 0.5 + + nimgs = hmcdata.shape[-1] + means = np.median(hmcdata[mask[..., np.newaxis]].reshape((-1, nimgs)).T, + axis=-1) + max_mean = means.max() + for i in range(nimgs): + hmcdata[..., i] *= max_mean / means[i] + + hmcdata = hmcdata.mean(axis=3) if self.inputs.zero_based_avg: - hmcdat -= hmcdat.min() + hmcdata -= hmcdata.min() + self._results['out_avg'] = fname_presuffix(self.inputs.in_files[0], + suffix='_avg', newpath=runtime.cwd) nb.Nifti1Image( - hmcdat, hmcnii.affine, hmcnii.header).to_filename( + hmcdata, filenii.affine, filenii.header).to_filename( self._results['out_avg']) return runtime diff --git a/niworkflows/interfaces/tests/test_images.py b/niworkflows/interfaces/tests/test_images.py index b40c86f0841..9e6cd358ce9 100644 --- a/niworkflows/interfaces/tests/test_images.py +++ b/niworkflows/interfaces/tests/test_images.py @@ -86,3 +86,35 @@ def test_signal_extraction_equivalence(tmp_path, nvols, nmasks, ext, factor): t2 = toc2 - toc assert t2 < t1 / factor + + +@pytest.mark.parametrize('shape, mshape', [ + ((10, 10, 10), (10, 10, 10)), + ((10, 10, 10, 1), (10, 10, 10)), + ((10, 10, 10, 1, 1), (10, 10, 10)), + ((10, 10, 10, 2), (10, 10, 10, 2)), + ((10, 10, 10, 2, 1), (10, 10, 10, 2)), + ((10, 10, 10, 2, 2), None)]) +def test_IntraModalMerge(tmpdir, shape, mshape): + """Exercise the various types of inputs.""" + tmpdir.chdir() + + data = np.random.normal(size=shape).astype('float32') + fname = str(tmpdir.join('file1.nii.gz')) + nb.Nifti1Image(data, np.eye(4), None).to_filename(fname) + + if mshape is None: + with pytest.raises(RuntimeError): + im.IntraModalMerge(in_files=fname).run() + return + + merged = str(im.IntraModalMerge(in_files=fname).run().outputs.out_file) + merged_data = nb.load(merged).get_fdata(dtype='float32') + assert merged_data.shape == mshape + assert np.allclose(np.squeeze(data), merged_data) + + merged = str(im.IntraModalMerge( + in_files=[fname, fname], hmc=False).run().outputs.out_file) + merged_data = nb.load(merged).get_fdata(dtype='float32') + new_mshape = (*mshape[:3], 2 if len(mshape) == 3 else mshape[3] * 2) + assert merged_data.shape == new_mshape