Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add nibabel-based split and merge interfaces #489

Merged
merged 22 commits into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
0de7374
[ENH] Add nibabel-based split and merge interfaces per https://github…
Mar 24, 2020
bfcd29c
Update niworkflows/interfaces/nibabel.py
dPys Mar 24, 2020
30ddf03
Update niworkflows/interfaces/nibabel.py
dPys Mar 24, 2020
5e3c93c
Update niworkflows/interfaces/nibabel.py
dPys Mar 24, 2020
40f6de1
Update niworkflows/interfaces/nibabel.py
dPys Mar 24, 2020
f3572e6
Chage naming to ConcatImages, fix in/out spec namings
Mar 24, 2020
49ff6bd
Update niworkflows/interfaces/nibabel.py
dPys Mar 31, 2020
87f1cb3
Update niworkflows/interfaces/nibabel.py
dPys Mar 31, 2020
b4fcdb7
rename ConcatImages to MergeSeries, correct typo in description of ou…
Mar 31, 2020
f51f510
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
9aa0655
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
5a31b01
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
53db81a
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
1c27f20
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
05724d5
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
7263d03
Apply suggestions from code review [skip ci]
oesteban Apr 8, 2020
03ebb6d
fix: added few bugfixes and regression tests
oesteban Apr 8, 2020
9446d47
fix: squeeze image with np.squeeze / change input name for consistency
oesteban Apr 8, 2020
0ae5351
sty(black): standardize formatting a bit
oesteban Apr 8, 2020
fc42351
enh: make i/o specs of SplitSeries more consistent [skip ci]
oesteban Apr 8, 2020
1d12dd3
Update niworkflows/interfaces/tests/test_nibabel.py [skip ci]
oesteban Apr 9, 2020
d657546
fix: apply review comments from @effigies, add parameterized tests
oesteban Apr 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 110 additions & 25 deletions niworkflows/interfaces/nibabel.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Nibabel-based interfaces."""
from pathlib import Path
import numpy as np
import nibabel as nb
from nipype import logging
from nipype.utils.filemanip import fname_presuffix
effigies marked this conversation as resolved.
Show resolved Hide resolved
from nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec, File,
SimpleInterface
traits,
TraitedSpec,
BaseInterfaceInputSpec,
File,
SimpleInterface,
OutputMultiObject,
InputMultiObject,
)

IFLOGGER = logging.getLogger('nipype.interface')
IFLOGGER = logging.getLogger("nipype.interface")


class _ApplyMaskInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='an image')
in_mask = File(exists=True, mandatory=True, desc='a mask')
threshold = traits.Float(0.5, usedefault=True,
desc='a threshold to the mask, if it is nonbinary')
in_file = File(exists=True, mandatory=True, desc="an image")
in_mask = File(exists=True, mandatory=True, desc="a mask")
threshold = traits.Float(
0.5, usedefault=True, desc="a threshold to the mask, if it is nonbinary"
)


class _ApplyMaskOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='masked file')
out_file = File(exists=True, desc="masked file")


class ApplyMask(SimpleInterface):
Expand All @@ -35,8 +42,9 @@ def _run_interface(self, runtime):
msknii = nb.load(self.inputs.in_mask)
msk = msknii.get_fdata() > self.inputs.threshold

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
)

if img.dataobj.shape[:3] != msk.shape:
raise ValueError("Image and mask sizes do not match.")
Expand All @@ -48,19 +56,18 @@ def _run_interface(self, runtime):
msk = msk[..., np.newaxis]

masked = img.__class__(img.dataobj * msk, None, img.header)
masked.to_filename(self._results['out_file'])
masked.to_filename(self._results["out_file"])
return runtime


class _BinarizeInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='input image')
thresh_low = traits.Float(mandatory=True,
desc='non-inclusive lower threshold')
in_file = File(exists=True, mandatory=True, desc="input image")
thresh_low = traits.Float(mandatory=True, desc="non-inclusive lower threshold")


class _BinarizeOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='masked file')
out_mask = File(exists=True, desc='output mask')
out_file = File(exists=True, desc="masked file")
out_mask = File(exists=True, desc="output mask")


class Binarize(SimpleInterface):
Expand All @@ -72,20 +79,98 @@ class Binarize(SimpleInterface):
def _run_interface(self, runtime):
img = nb.load(self.inputs.in_file)

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
self._results['out_mask'] = fname_presuffix(
self.inputs.in_file, suffix='_mask', newpath=runtime.cwd)
self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
)
self._results["out_mask"] = fname_presuffix(
self.inputs.in_file, suffix="_mask", newpath=runtime.cwd
)

data = img.get_fdata()
mask = data > self.inputs.thresh_low
data[~mask] = 0.0
masked = img.__class__(data, img.affine, img.header)
masked.to_filename(self._results['out_file'])
masked.to_filename(self._results["out_file"])

img.header.set_data_dtype('uint8')
maskimg = img.__class__(mask.astype('uint8'), img.affine,
img.header)
maskimg.to_filename(self._results['out_mask'])
img.header.set_data_dtype("uint8")
maskimg = img.__class__(mask.astype("uint8"), img.affine, img.header)
maskimg.to_filename(self._results["out_mask"])

return runtime

dPys marked this conversation as resolved.
Show resolved Hide resolved

class _SplitSeriesInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc="input 4d image")

oesteban marked this conversation as resolved.
Show resolved Hide resolved

class _SplitSeriesOutputSpec(TraitedSpec):
out_files = OutputMultiObject(File(exists=True), desc="output list of 3d images")


class SplitSeries(SimpleInterface):
"""Split a 4D dataset along the last dimension into a series of 3D volumes."""

input_spec = _SplitSeriesInputSpec
output_spec = _SplitSeriesOutputSpec

def _run_interface(self, runtime):
in_file = self.inputs.in_file
img = nb.load(in_file)
extra_dims = tuple(dim for dim in img.shape[3:] if dim > 1) or (1,)
if len(extra_dims) != 1:
raise ValueError(f"Invalid shape {'x'.join(str(s) for s in img.shape)}")
img = img.__class__(img.dataobj.reshape(img.shape[:3] + extra_dims),
img.affine, img.header)

self._results["out_files"] = []
for i, img_3d in enumerate(nb.four_to_three(img)):
out_file = str(
Path(fname_presuffix(in_file, suffix=f"_idx-{i:03}")).absolute()
)
img_3d.to_filename(out_file)
self._results["out_files"].append(out_file)

return runtime


class _MergeSeriesInputSpec(BaseInterfaceInputSpec):
in_files = InputMultiObject(
File(exists=True, mandatory=True, desc="input list of 3d images")
)
allow_4D = traits.Bool(
True, usedefault=True, desc="whether 4D images are allowed to be concatenated"
)

oesteban marked this conversation as resolved.
Show resolved Hide resolved

class _MergeSeriesOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="output 4d image")


class MergeSeries(SimpleInterface):
"""Merge a series of 3D volumes along the last dimension into a single 4D image."""

input_spec = _MergeSeriesInputSpec
output_spec = _MergeSeriesOutputSpec

def _run_interface(self, runtime):
nii_list = []
for f in self.inputs.in_files:
filenii = nb.squeeze_image(nb.load(f))
ndim = filenii.dataobj.ndim
if ndim == 3:
nii_list.append(filenii)
continue
elif self.inputs.allow_4D and ndim == 4:
nii_list += nb.four_to_three(filenii)
continue
else:
raise ValueError(
"Input image has an incorrect number of dimensions" f" ({ndim})."
)

img_4d = nb.concat_images(nii_list)
out_file = fname_presuffix(self.inputs.in_files[0], suffix="_merged")
img_4d.to_filename(out_file)

self._results["out_file"] = out_file
return runtime
83 changes: 71 additions & 12 deletions niworkflows/interfaces/tests/test_nibabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import nibabel as nb
import pytest

from ..nibabel import Binarize, ApplyMask
from ..nibabel import Binarize, ApplyMask, SplitSeries, MergeSeries


def test_Binarize(tmp_path):
Expand All @@ -14,10 +14,10 @@ def test_Binarize(tmp_path):
mask = np.zeros((20, 20, 20), dtype=bool)
mask[5:15, 5:15, 5:15] = bool

data = np.zeros_like(mask, dtype='float32')
data = np.zeros_like(mask, dtype="float32")
data[mask] = np.random.gamma(2, size=mask.sum())

in_file = tmp_path / 'input.nii.gz'
in_file = tmp_path / "input.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

binif = Binarize(thresh_low=0.0, in_file=str(in_file)).run()
Expand All @@ -36,28 +36,32 @@ def test_ApplyMask(tmp_path):
mask[8:11, 8:11, 8:11] = 1.0

# Test the 3D
in_file = tmp_path / 'input3D.nii.gz'
in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

in_mask = tmp_path / 'mask.nii.gz'
in_mask = tmp_path / "mask.nii.gz"
nb.Nifti1Image(mask, np.eye(4), None).to_filename(str(in_mask))

masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3

masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.6).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3

data4d = np.stack((data, 2 * data, 3 * data), axis=-1)
# Test the 4D case
in_file4d = tmp_path / 'input4D.nii.gz'
in_file4d = tmp_path / "input4D.nii.gz"
nb.Nifti1Image(data4d, np.eye(4), None).to_filename(str(in_file4d))

masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3 * 6
masked1 = ApplyMask(
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4
).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3 * 6

masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3 * 6
masked1 = ApplyMask(
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6
).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3 * 6

# Test errors
nb.Nifti1Image(mask, 2 * np.eye(4), None).to_filename(str(in_mask))
Expand All @@ -69,3 +73,58 @@ def test_ApplyMask(tmp_path):
ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
with pytest.raises(ValueError):
ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()


@pytest.mark.parametrize("shape,exp_n", [
((20, 20, 20, 15), 15),
((20, 20, 20), 1),
((20, 20, 20, 1), 1),
((20, 20, 20, 1, 3), 3),
((20, 20, 20, 3, 1), 3),
((20, 20, 20, 1, 3, 3), -1),
((20, 1, 20, 15), 15),
((20, 1, 20), 1),
((20, 1, 20, 1), 1),
((20, 1, 20, 1, 3), 3),
((20, 1, 20, 3, 1), 3),
((20, 1, 20, 1, 3, 3), -1),
])
def test_SplitSeries(tmp_path, shape, exp_n):
"""Test 4-to-3 NIfTI split interface."""
os.chdir(tmp_path)

in_file = str(tmp_path / "input.nii.gz")
nb.Nifti1Image(np.ones(shape, dtype=float), np.eye(4), None).to_filename(in_file)

_interface = SplitSeries(in_file=in_file)
if exp_n > 0:
split = _interface.run()
n = int(isinstance(split.outputs.out_files, str)) or len(split.outputs.out_files)
assert n == exp_n
else:
with pytest.raises(ValueError):
_interface.run()


def test_MergeSeries(tmp_path):
"""Test 3-to-4 NIfTI concatenation interface."""
os.chdir(str(tmp_path))

in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(np.ones((20, 20, 20), dtype=float), np.eye(4), None).to_filename(
str(in_file)
)

merge = MergeSeries(in_files=[str(in_file)] * 5).run()
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)

in_4D = tmp_path / "input4D.nii.gz"
nb.Nifti1Image(np.ones((20, 20, 20, 4), dtype=float), np.eye(4), None).to_filename(
str(in_4D)
)

merge = MergeSeries(in_files=[str(in_file)] + [str(in_4D)]).run()
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)

with pytest.raises(ValueError):
MergeSeries(in_files=[str(in_file)] + [str(in_4D)], allow_4D=False).run()