Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add option to infer CIFTI-2 intent codes #932

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
101 changes: 95 additions & 6 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..nifti1 import Nifti1Extensions
from ..nifti2 import Nifti2Image, Nifti2Header
from ..arrayproxy import reshape_dataobj
from ..volumeutils import Recoder
from warnings import warn


Expand Down Expand Up @@ -89,6 +90,53 @@ class Cifti2HeaderError(Exception):
'CIFTI_STRUCTURE_THALAMUS_LEFT',
'CIFTI_STRUCTURE_THALAMUS_RIGHT')

# "Standard CIFTI Mapping Combinations" within CIFTI-2 spec
# https://www.nitrc.org/forum/attachment.php?attachid=341&group_id=454&forum_id=1955
CIFTI_CODES = Recoder((
('.dconn.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE', (
'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dtseries.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', (
'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pconn.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.ptseries.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES', (
'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.dscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dlabel.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS', (
'CIFTI_INDEX_TYPE_LABELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.pdconn.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE', (
'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.dpconn.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pconnseries.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SERIES', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SERIES',
)),
('.pconnscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SCALAR', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SCALARS',
)),
('.dfan.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dfibersamp.nii', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dfansamp.nii', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
), fields=('extension', 'niistring', 'map_types'))


def _value_if_klass(val, klass):
if val is None or isinstance(val, klass):
Expand Down Expand Up @@ -1466,11 +1514,7 @@ def to_file_map(self, file_map=None):
raise ValueError(
f"Dataobj shape {self._dataobj.shape} does not match shape "
f"expected from CIFTI-2 header {self.header.matrix.get_data_shape()}")
# if intent code is not set, default to unknown CIFTI
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
data = reshape_dataobj(self.dataobj,
(1, 1, 1, 1) + self.dataobj.shape)
data = reshape_dataobj(self.dataobj, (1, 1, 1, 1) + self.dataobj.shape)
# If qform not set, reset pixdim values so Nifti2 does not complain
if header['qform_code'] == 0:
header['pixdim'][:4] = 1
Expand Down Expand Up @@ -1501,14 +1545,59 @@ def update_headers(self):
>>> img.shape == (2, 3, 4)
True
"""
self._nifti_header.set_data_shape((1, 1, 1, 1) + self._dataobj.shape)
header = self._nifti_header
header.set_data_shape((1, 1, 1, 1) + self._dataobj.shape)
# if intent code is not set, default to unknown
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')

def get_data_dtype(self):
return self._nifti_header.get_data_dtype()

def set_data_dtype(self, dtype):
self._nifti_header.set_data_dtype(dtype)

def to_filename(self, filename, validate=True):
"""
Ensures NIfTI header intent code is set prior to saving.

Parameters
----------
validate : boolean, optional
If ``True``, infer and validate CIFTI type based on MatrixIndicesMap values.
This includes the setting of the relevant intent code within the NIfTI header.
If validation fails, a UserWarning is issued and saving continues.
"""
if validate:
# Determine CIFTI type via index maps
from .parse_cifti2 import intent_codes

matrix = self.header.matrix
map_types = tuple(
matrix.get_index_map(idx).indices_map_to_data_type for idx
in sorted(matrix.mapped_indices)
)
try:
expected_intent = CIFTI_CODES.niistring[map_types]
expected_ext = CIFTI_CODES.extension[map_types]
except KeyError: # unknown
expected_intent = "NIFTI_INTENT_CONNECTIVITY_UNKNOWN"
expected_ext = None
warn(
"No information found for matrix containing the following index maps:"
f"{map_types}, defaulting to unknown."
)

orig_intent = self._nifti_header.get_intent()[0]
if expected_intent != intent_codes.niistring[orig_intent]:
warn(
f"Expected NIfTI intent: {expected_intent} has been automatically set."
)
self._nifti_header.set_intent(expected_intent)
if expected_ext is not None and not filename.endswith(expected_ext):
warn(f"Filename does not end with expected extension: {expected_ext}")
super().to_filename(filename)


load = Cifti2Image.from_filename
save = Cifti2Image.instance_to_filename
15 changes: 15 additions & 0 deletions nibabel/cifti2/tests/test_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,18 @@ def make_imaker(self, arr, header=None, ni_header=None):
)
header.matrix.append(mim)
return lambda: self.image_maker(arr.copy(), header, ni_header)

def validate_filenames(self, imaker, params, validate=False):
super().validate_filenames(imaker, params, validate=validate)

def validate_mmap_parameter(self, imaker, params, validate=False):
super().validate_mmap_parameter(imaker, params, validate=validate)

def validate_to_bytes(self, imaker, params, validate=False):
super().validate_to_bytes(imaker, params, validate=validate)

def validate_from_bytes(self, imaker, params, validate=False):
super().validate_from_bytes(imaker, params, validate=validate)

def validate_to_from_bytes(self, imaker, params, validate=False):
super().validate_to_from_bytes(imaker, params, validate=validate)
2 changes: 1 addition & 1 deletion nibabel/cifti2/tests/test_cifti2io_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def check_rewrite(arr, axes, extension='.nii'):
custom extension to use
"""
(fd, name) = tempfile.mkstemp(extension)
cifti2.Cifti2Image(arr, header=axes).to_filename(name)
cifti2.Cifti2Image(arr, header=axes).to_filename(name, validate=False)
img = nib.load(name)
arr2 = img.get_fdata()
assert np.allclose(arr, arr2)
Expand Down
4 changes: 2 additions & 2 deletions nibabel/cifti2/tests/test_cifti2io_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_readwritedata():
with InTemporaryDirectory():
for name in datafiles:
img = ci.load(name)
ci.save(img, 'test.nii')
ci.save(img, 'test.nii', validate=False)
img2 = ci.load('test.nii')
assert len(img.header.matrix) == len(img2.header.matrix)
# Order should be preserved in load/save
Expand All @@ -109,7 +109,7 @@ def test_nibabel_readwritedata():
with InTemporaryDirectory():
for name in datafiles:
img = nib.load(name)
nib.save(img, 'test.nii')
nib.save(img, 'test.nii', validate=False)
img2 = nib.load('test.nii')
assert len(img.header.matrix) == len(img2.header.matrix)
# Order should be preserved in load/save
Expand Down
57 changes: 41 additions & 16 deletions nibabel/cifti2/tests/test_new_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
scratch.
"""
import numpy as np

import nibabel as nib
from nibabel import cifti2 as ci
from nibabel.tmpdirs import InTemporaryDirectory

import pytest

from ...testing import (
clear_and_catch_warnings, error_warnings, suppress_warnings, assert_array_equal)

Expand Down Expand Up @@ -237,7 +236,6 @@ def test_dtseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.dtseries.nii')
Expand Down Expand Up @@ -281,7 +279,6 @@ def test_dlabel():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS')

with InTemporaryDirectory():
ci.save(img, 'test.dlabel.nii')
Expand All @@ -301,7 +298,6 @@ def test_dconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE')

with InTemporaryDirectory():
ci.save(img, 'test.dconn.nii')
Expand All @@ -323,7 +319,6 @@ def test_ptseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.ptseries.nii')
Expand All @@ -345,7 +340,6 @@ def test_pscalar():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR')

with InTemporaryDirectory():
ci.save(img, 'test.pscalar.nii')
Expand All @@ -367,7 +361,6 @@ def test_pdconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE')

with InTemporaryDirectory():
ci.save(img, 'test.pdconn.nii')
Expand All @@ -389,7 +382,6 @@ def test_dpconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED')

with InTemporaryDirectory():
ci.save(img, 'test.dpconn.nii')
Expand All @@ -413,7 +405,7 @@ def test_plabel():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.plabel.nii')
ci.save(img, 'test.plabel.nii', validate=False)
img2 = ci.load('test.plabel.nii')
assert img.nifti_header.get_intent()[0] == 'ConnUnknown'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -430,7 +422,6 @@ def test_pconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED')

with InTemporaryDirectory():
ci.save(img, 'test.pconn.nii')
Expand All @@ -453,8 +444,6 @@ def test_pconnseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4, 13)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.pconnseries.nii')
Expand All @@ -478,8 +467,6 @@ def test_pconnscalar():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4, 2)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SCALAR')

with InTemporaryDirectory():
ci.save(img, 'test.pconnscalar.nii')
Expand Down Expand Up @@ -517,7 +504,45 @@ def test_wrong_shape():
ci.Cifti2Image(data, hdr)
with suppress_warnings():
img = ci.Cifti2Image(data, hdr)

with pytest.raises(ValueError):
img.to_file_map()


def test_cifti_validation():
# flip label / brain_model index maps
geometry_map = create_geometry_map((0, ))
label_map = create_label_map((1, ))
matrix = ci.Cifti2Matrix()
matrix.append(geometry_map)
matrix.append(label_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 2)
img = ci.Cifti2Image(data, hdr)
# flipped index maps will warn
with InTemporaryDirectory(), pytest.warns(UserWarning):
ci.save(img, 'test.dlabel.nii')

label_map = create_label_map((0, ))
geometry_map = create_geometry_map((1, ))
matrix = ci.Cifti2Matrix()
matrix.append(label_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.validate.nii', validate=False)
ci.save(img, 'test.dlabel.nii')

img2 = nib.load('test.dlabel.nii')
img3 = nib.load('test.validate.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnDenseLabel'
assert img3.nifti_header.get_intent()[0] == 'ConnUnknown'
assert isinstance(img2, ci.Cifti2Image)
assert isinstance(img3, ci.Cifti2Image)
assert_array_equal(img2.get_fdata(), data)
check_label_map(img2.header.matrix.get_index_map(0))
check_geometry_map(img2.header.matrix.get_index_map(1))
del img2, img3
6 changes: 3 additions & 3 deletions nibabel/filebasedimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def filespec_to_file_map(klass, filespec):
def filespec_to_files(klass, filespec):
return klass.filespec_to_file_map(filespec)

def to_filename(self, filename):
def to_filename(self, filename, **kwargs):
""" Write image to files implied by filename string

Parameters
Expand Down Expand Up @@ -381,7 +381,7 @@ def make_file_map(klass, mapping=None):
load = from_filename

@classmethod
def instance_to_filename(klass, img, filename):
def instance_to_filename(klass, img, filename, **kwargs):
""" Save `img` in our own format, to name implied by `filename`

This is a class method
Expand All @@ -394,7 +394,7 @@ def instance_to_filename(klass, img, filename):
Filename, implying name to which to save image.
"""
img = klass.from_image(img)
img.to_filename(filename)
img.to_filename(filename, **kwargs)

@classmethod
def from_image(klass, img):
Expand Down
6 changes: 3 additions & 3 deletions nibabel/loadsave.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def guessed_image_type(filename):
raise ImageFileError(f'Cannot work out file type of "{filename}"')


def save(img, filename):
def save(img, filename, **kwargs):
""" Save an image to file adapting format to `filename`

Parameters
Expand All @@ -96,7 +96,7 @@ def save(img, filename):

# Save the type as expected
try:
img.to_filename(filename)
img.to_filename(filename, **kwargs)
mgxd marked this conversation as resolved.
Show resolved Hide resolved
except ImageFileError:
pass
else:
Expand Down Expand Up @@ -144,7 +144,7 @@ def save(img, filename):
# Here, we either have a klass or a converted image.
if converted is None:
converted = klass.from_image(img)
converted.to_filename(filename)
converted.to_filename(filename, **kwargs)


@deprecate_with_version('read_img_data deprecated. '
Expand Down
Loading