Skip to content

Commit

Permalink
fix: address @effigies' comments and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Dec 17, 2019
1 parent 797e375 commit 144cc43
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 6 deletions.
38 changes: 32 additions & 6 deletions niworkflows/interfaces/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from nipype.utils.filemanip import fname_presuffix
from nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface,
File, InputMultiPath, OutputMultiPath)
File, InputMultiPath, OutputMultiPath, isdefined)
from nipype.interfaces import fsl

LOGGER = logging.getLogger('nipype.interface')
Expand All @@ -21,9 +21,11 @@
class _IntraModalMergeInputSpec(BaseInterfaceInputSpec):
in_files = InputMultiPath(File(exists=True), mandatory=True,
desc='input files')
in_mask = File(exists=True, desc='input mask for grand mean scaling')
hmc = traits.Bool(True, usedefault=True)
zero_based_avg = traits.Bool(True, usedefault=True)
to_ras = traits.Bool(True, usedefault=True)
grand_mean_scaling = traits.Bool(False, usedefault=True)


class _IntraModalMergeOutputSpec(TraitedSpec):
Expand Down Expand Up @@ -54,23 +56,31 @@ def _run_interface(self, runtime):
in_files = [reorient(inf, newpath=runtime.cwd)
for inf in in_files]

run_hmc = len(in_files) > 1
run_hmc = self.inputs.hmc and len(in_files) > 1

nii_list = []
# Remove one-sized extra dimensions
for i, f in enumerate(in_files):
filenii = nb.load(f)
filenii = nb.squeeze_image(filenii)
if len(filenii.shape) == 5:
raise RuntimeError('Input image (%s) is 5D.' % f)
nii_list += nb.four_to_three(filenii)
if filenii.dataobj.ndim == 4:
nii_list += nb.four_to_three(filenii)
else:
nii_list.append(filenii)

if len(nii_list) > 1:
filenii = nb.concat_images(nii_list)
else:
filenii = nii_list[0]

merged_fname = fname_presuffix(self.inputs.in_files[0],
suffix='_merged', newpath=runtime.cwd)
filenii = nb.concat_images(nii_list)
filenii.to_filename(merged_fname)

self._results['out_file'] = merged_fname
self._results['out_avg'] = merged_fname

if filenii.dataobj.ndim < 4:
# TODO: generate identity out_mats and zero-filled out_movpar
return runtime
Expand All @@ -84,7 +94,23 @@ def _run_interface(self, runtime):
self._results['out_mats'] = mcres.outputs.mat_file
self._results['out_movpar'] = mcres.outputs.par_file

hmcdata = filenii.get_fdata(dtype='float32').mean(axis=3)
hmcdata = filenii.get_fdata(dtype='float32')
if self.inputs.grand_mean_scaling:
if not isdefined(self.inputs.mask):
mean = np.median(hmcdata, axis=-1)
thres = np.percentile(mean, 25)
mask = mean > thres
else:
mask = nb.load(self.inputs.in_mask).get_fdata(dtype='float32') > 0.5

nimgs = hmcdata.shape[-1]
means = np.median(hmcdata[mask[..., np.newaxis]].reshape((-1, nimgs)).T,
axis=-1)
max_mean = means.max()
for i in range(nimgs):
hmcdata[..., i] *= max_mean / means[i]

hmcdata = hmcdata.mean(axis=3)
if self.inputs.zero_based_avg:
hmcdata -= hmcdata.min()

Expand Down
32 changes: 32 additions & 0 deletions niworkflows/interfaces/tests/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,35 @@ def test_signal_extraction_equivalence(tmp_path, nvols, nmasks, ext, factor):
t2 = toc2 - toc

assert t2 < t1 / factor


@pytest.mark.parametrize('shape, mshape', [
((10, 10, 10), (10, 10, 10)),
((10, 10, 10, 1), (10, 10, 10)),
((10, 10, 10, 1, 1), (10, 10, 10)),
((10, 10, 10, 2), (10, 10, 10, 2)),
((10, 10, 10, 2, 1), (10, 10, 10, 2)),
((10, 10, 10, 2, 2), None)])
def test_IntraModalMerge(tmpdir, shape, mshape):
"""Exercise the various types of inputs."""
tmpdir.chdir()

data = np.random.normal(size=shape).astype('float32')
fname = str(tmpdir.join('file1.nii.gz'))
nb.Nifti1Image(data, np.eye(4), None).to_filename(fname)

if mshape is None:
with pytest.raises(RuntimeError):
im.IntraModalMerge(in_files=fname).run()
return

merged = str(im.IntraModalMerge(in_files=fname).run().outputs.out_file)
merged_data = nb.load(merged).get_fdata(dtype='float32')
assert merged_data.shape == mshape
assert np.allclose(np.squeeze(data), merged_data)

merged = str(im.IntraModalMerge(
in_files=[fname, fname], hmc=False).run().outputs.out_file)
merged_data = nb.load(merged).get_fdata(dtype='float32')
new_mshape = (*mshape[:3], 2 if len(mshape) == 3 else mshape[3] * 2)
assert merged_data.shape == new_mshape

0 comments on commit 144cc43

Please sign in to comment.