Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Read (and apply) ITK/ANTs' composite HDF5 transforms #79

Merged
merged 5 commits into from
Mar 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 3 additions & 136 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""
from pathlib import Path
from collections.abc import Iterable
import numpy as np
import h5py
import warnings
Expand Down Expand Up @@ -168,10 +167,10 @@ def __ne__(self, other):
return not self == other


class TransformBase(object):
class TransformBase:
"""Abstract image class to represent transforms."""

__slots__ = ['_reference']
__slots__ = ('_reference', )

def __init__(self, reference=None):
"""Instantiate a transform."""
Expand All @@ -191,13 +190,11 @@ def __add__(self, b):
-------
>>> T1 = TransformBase()
>>> added = T1 + TransformBase()
>>> isinstance(added, TransformChain)
True

>>> len(added.transforms)
2

"""
from .manip import TransformChain
return TransformChain(transforms=[self, b])

@property
Expand Down Expand Up @@ -322,127 +319,6 @@ def _to_hdf5(self, x5_root):
raise NotImplementedError


class TransformChain(TransformBase):
"""Implements the concatenation of transforms."""

__slots__ = ['_transforms']

def __init__(self, transforms=None):
"""Initialize a chain of transforms."""
self._transforms = None
if transforms is not None:
self.transforms = transforms

def __add__(self, b):
"""
Compose this and other transforms.

Example
-------
>>> T1 = TransformBase()
>>> added = T1 + TransformBase() + TransformBase()
>>> isinstance(added, TransformChain)
True

>>> len(added.transforms)
3

"""
self.append(b)
return self

def __getitem__(self, i):
"""
Enable indexed access of transform chains.

Example
-------
>>> T1 = TransformBase()
>>> chain = T1 + TransformBase()
>>> chain[0] is T1
True

"""
return self.transforms[i]

def __len__(self):
"""Enable using len()."""
return len(self.transforms)

@property
def transforms(self):
"""Get the internal list of transforms."""
return self._transforms

@transforms.setter
def transforms(self, value):
self._transforms = _as_chain(value)
if self.transforms[0].reference:
self.reference = self.transforms[0].reference

def append(self, x):
"""
Concatenate one element to the chain.

Example
-------
>>> chain = TransformChain(transforms=TransformBase())
>>> chain.append((TransformBase(), TransformBase()))
>>> len(chain)
3

"""
self.transforms += _as_chain(x)

def insert(self, i, x):
"""
Insert an item at a given position.

Example
-------
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
>>> chain.insert(1, TransformBase())
>>> len(chain)
3

>>> chain.insert(1, TransformChain(chain))
>>> len(chain)
6

"""
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]

def map(self, x, inverse=False):
"""
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.

Example
-------
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]

>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]

>>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TransformError:

"""
if not self.transforms:
raise TransformError('Cannot apply an empty transforms chain.')

transforms = self.transforms
if inverse:
transforms = reversed(self.transforms)

for xfm in transforms:
x = xfm(x, inverse=inverse)

return x


def _as_homogeneous(xyz, dtype='float32', dim=3):
"""
Convert 2D and 3D coordinates into homogeneous coordinates.
Expand Down Expand Up @@ -473,12 +349,3 @@ def _as_homogeneous(xyz, dtype='float32', dim=3):
def _apply_affine(x, affine, dim):
"""Get the image array's indexes corresponding to coordinates."""
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T


def _as_chain(x):
"""Convert a value into a transform chain."""
if isinstance(x, TransformChain):
return x.transforms
if isinstance(x, Iterable):
return list(x)
return [x]
67 changes: 66 additions & 1 deletion nitransforms/io/itk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
import numpy as np
from scipy.io import savemat as _save_mat
from nibabel import Nifti1Header, Nifti1Image
from nibabel.affines import from_matvec
from .base import (
BaseLinearTransformList,
Expand Down Expand Up @@ -29,7 +30,9 @@ def __init__(self, parameters=None, offset=None):
"""Initialize with default offset and index."""
super().__init__()
self.structarr['index'] = 0
self.structarr['offset'] = offset or [0, 0, 0]
if offset is None:
offset = np.zeros((3,), dtype='float')
self.structarr['offset'] = offset
self.structarr['parameters'] = np.eye(4)
if parameters is not None:
self.structarr['parameters'] = parameters
Expand Down Expand Up @@ -280,3 +283,65 @@ def from_image(cls, imgobj):
field[..., (0, 1)] *= -1.0

return imgobj.__class__(field, imgobj.affine, hdr)


class ITKCompositeH5:
"""A data structure for ITK's HDF5 files."""

@classmethod
def from_filename(cls, filename):
"""Read the struct from a file given its path."""
from h5py import File as H5File
if not str(filename).endswith('.h5'):
raise RuntimeError("Extension is not .h5")

with H5File(str(filename)) as f:
return cls.from_h5obj(f)

@classmethod
def from_h5obj(cls, fileobj, check=True):
"""Read the struct from a file object."""
xfm_list = []
h5group = fileobj["TransformGroup"]
typo_fallback = "Transform"
try:
h5group['1'][f"{typo_fallback}Parameters"]
except KeyError:
typo_fallback = "Tranform"

for xfm in list(h5group.values())[1:]:
if xfm["TransformType"][0].startswith(b"AffineTransform"):
_params = np.asanyarray(xfm[f"{typo_fallback}Parameters"])
xfm_list.append(
ITKLinearTransform(
parameters=from_matvec(_params[:-3].reshape(3, 3), _params[-3:]),
offset=np.asanyarray(xfm[f"{typo_fallback}FixedParameters"])
)
)
continue
if xfm["TransformType"][0].startswith(b"DisplacementFieldTransform"):
_fixed = np.asanyarray(xfm[f"{typo_fallback}FixedParameters"])
shape = _fixed[:3].astype('uint16').tolist()
offset = _fixed[3:6].astype('uint16')
zooms = _fixed[6:9].astype('float')
directions = _fixed[9:].astype('float').reshape((3, 3))
affine = from_matvec(directions * zooms, offset)
field = np.asanyarray(xfm[f"{typo_fallback}Parameters"]).reshape(
tuple(shape + [1, -1])
)
hdr = Nifti1Header()
hdr.set_intent("vector")
hdr.set_data_dtype("float")

xfm_list.append(
ITKDisplacementsField.from_image(
Nifti1Image(field.astype("float"), affine, hdr)
)
)
continue

raise NotImplementedError(
f"Unsupported transform type {xfm['TransformType'][0]}"
)

return xfm_list
Loading