Skip to content

Commit

Permalink
Merge pull request #441 from oesteban/fix/sdcflows-77
Browse files Browse the repository at this point in the history
FIX: ``IntraModalMerge`` failed for dims (x, y, z, 1)
  • Loading branch information
oesteban authored Dec 18, 2019
2 parents c88c521 + 144cc43 commit b7f9955
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 35 deletions.
102 changes: 67 additions & 35 deletions niworkflows/interfaces/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from nipype.utils.filemanip import fname_presuffix
from nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface,
File, InputMultiPath, OutputMultiPath)
File, InputMultiPath, OutputMultiPath, isdefined)
from nipype.interfaces import fsl

LOGGER = logging.getLogger('nipype.interface')
Expand All @@ -21,9 +21,11 @@
class _IntraModalMergeInputSpec(BaseInterfaceInputSpec):
in_files = InputMultiPath(File(exists=True), mandatory=True,
desc='input files')
in_mask = File(exists=True, desc='input mask for grand mean scaling')
hmc = traits.Bool(True, usedefault=True)
zero_based_avg = traits.Bool(True, usedefault=True)
to_ras = traits.Bool(True, usedefault=True)
grand_mean_scaling = traits.Bool(False, usedefault=True)


class _IntraModalMergeOutputSpec(TraitedSpec):
Expand All @@ -34,6 +36,14 @@ class _IntraModalMergeOutputSpec(TraitedSpec):


class IntraModalMerge(SimpleInterface):
"""
Calculate an average of the inputs.
If the input is 3D, returns the original image.
Otherwise, splits the images and merges them after
head-motion correction with FSL ``mcflirt``.
"""

input_spec = _IntraModalMergeInputSpec
output_spec = _IntraModalMergeOutputSpec

Expand All @@ -42,50 +52,72 @@ def _run_interface(self, runtime):
if not isinstance(in_files, list):
in_files = [self.inputs.in_files]

# Generate output average name early
self._results['out_avg'] = fname_presuffix(self.inputs.in_files[0],
suffix='_avg', newpath=runtime.cwd)

if self.inputs.to_ras:
in_files = [reorient(inf, newpath=runtime.cwd)
for inf in in_files]

if len(in_files) == 1:
filenii = nb.load(in_files[0])
run_hmc = self.inputs.hmc and len(in_files) > 1

# magnitude files can have an extra dimension empty
nii_list = []
# Remove one-sized extra dimensions
for i, f in enumerate(in_files):
filenii = nb.load(f)
filenii = nb.squeeze_image(filenii)
if len(filenii.shape) == 5:
filenii = nb.squeeze_image(filenii)
if len(filenii.shape) == 5:
raise RuntimeError('Input image (%s) is 5D' % in_files[0])

in_files = [fname_presuffix(in_files[0], suffix='_squeezed',
newpath=runtime.cwd)]
filenii.to_filename(in_files[0])

if filenii.dataobj.ndim < 4:
self._results['out_file'] = in_files[0]
self._results['out_avg'] = in_files[0]
# TODO: generate identity out_mats and zero-filled out_movpar
return runtime
in_files = in_files[0]
raise RuntimeError('Input image (%s) is 5D.' % f)
if filenii.dataobj.ndim == 4:
nii_list += nb.four_to_three(filenii)
else:
nii_list.append(filenii)

if len(nii_list) > 1:
filenii = nb.concat_images(nii_list)
else:
magmrg = fsl.Merge(dimension='t', in_files=self.inputs.in_files)
in_files = magmrg.run().outputs.merged_file
mcflirt = fsl.MCFLIRT(cost='normcorr', save_mats=True, save_plots=True,
ref_vol=0, in_file=in_files)
mcres = mcflirt.run()
self._results['out_mats'] = mcres.outputs.mat_file
self._results['out_movpar'] = mcres.outputs.par_file
self._results['out_file'] = mcres.outputs.out_file

hmcnii = nb.load(mcres.outputs.out_file)
hmcdat = hmcnii.get_fdata().mean(axis=3)
filenii = nii_list[0]

merged_fname = fname_presuffix(self.inputs.in_files[0],
suffix='_merged', newpath=runtime.cwd)
filenii.to_filename(merged_fname)
self._results['out_file'] = merged_fname
self._results['out_avg'] = merged_fname

if filenii.dataobj.ndim < 4:
# TODO: generate identity out_mats and zero-filled out_movpar
return runtime

if run_hmc:
mcflirt = fsl.MCFLIRT(cost='normcorr', save_mats=True, save_plots=True,
ref_vol=0, in_file=merged_fname)
mcres = mcflirt.run()
filenii = nb.load(mcres.outputs.out_file)
self._results['out_file'] = mcres.outputs.out_file
self._results['out_mats'] = mcres.outputs.mat_file
self._results['out_movpar'] = mcres.outputs.par_file

hmcdata = filenii.get_fdata(dtype='float32')
if self.inputs.grand_mean_scaling:
if not isdefined(self.inputs.mask):
mean = np.median(hmcdata, axis=-1)
thres = np.percentile(mean, 25)
mask = mean > thres
else:
mask = nb.load(self.inputs.in_mask).get_fdata(dtype='float32') > 0.5

nimgs = hmcdata.shape[-1]
means = np.median(hmcdata[mask[..., np.newaxis]].reshape((-1, nimgs)).T,
axis=-1)
max_mean = means.max()
for i in range(nimgs):
hmcdata[..., i] *= max_mean / means[i]

hmcdata = hmcdata.mean(axis=3)
if self.inputs.zero_based_avg:
hmcdat -= hmcdat.min()
hmcdata -= hmcdata.min()

self._results['out_avg'] = fname_presuffix(self.inputs.in_files[0],
suffix='_avg', newpath=runtime.cwd)
nb.Nifti1Image(
hmcdat, hmcnii.affine, hmcnii.header).to_filename(
hmcdata, filenii.affine, filenii.header).to_filename(
self._results['out_avg'])

return runtime
Expand Down
32 changes: 32 additions & 0 deletions niworkflows/interfaces/tests/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,35 @@ def test_signal_extraction_equivalence(tmp_path, nvols, nmasks, ext, factor):
t2 = toc2 - toc

assert t2 < t1 / factor


@pytest.mark.parametrize('shape, mshape', [
((10, 10, 10), (10, 10, 10)),
((10, 10, 10, 1), (10, 10, 10)),
((10, 10, 10, 1, 1), (10, 10, 10)),
((10, 10, 10, 2), (10, 10, 10, 2)),
((10, 10, 10, 2, 1), (10, 10, 10, 2)),
((10, 10, 10, 2, 2), None)])
def test_IntraModalMerge(tmpdir, shape, mshape):
"""Exercise the various types of inputs."""
tmpdir.chdir()

data = np.random.normal(size=shape).astype('float32')
fname = str(tmpdir.join('file1.nii.gz'))
nb.Nifti1Image(data, np.eye(4), None).to_filename(fname)

if mshape is None:
with pytest.raises(RuntimeError):
im.IntraModalMerge(in_files=fname).run()
return

merged = str(im.IntraModalMerge(in_files=fname).run().outputs.out_file)
merged_data = nb.load(merged).get_fdata(dtype='float32')
assert merged_data.shape == mshape
assert np.allclose(np.squeeze(data), merged_data)

merged = str(im.IntraModalMerge(
in_files=[fname, fname], hmc=False).run().outputs.out_file)
merged_data = nb.load(merged).get_fdata(dtype='float32')
new_mshape = (*mshape[:3], 2 if len(mshape) == 3 else mshape[3] * 2)
assert merged_data.shape == new_mshape

0 comments on commit b7f9955

Please sign in to comment.