Skip to content

Commit

Permalink
feat: split transform chains out from base and add a load function
Browse files Browse the repository at this point in the history
This was necessary to integrate one test equivalent to resampling with
``antsApplyTransforms``, but via nitransforms.
  • Loading branch information
oesteban committed Mar 26, 2020
1 parent 3869064 commit 26f9e80
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 141 deletions.
139 changes: 3 additions & 136 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""
from pathlib import Path
from collections.abc import Iterable
import numpy as np
import h5py
import warnings
Expand Down Expand Up @@ -168,10 +167,10 @@ def __ne__(self, other):
return not self == other


class TransformBase(object):
class TransformBase:
"""Abstract image class to represent transforms."""

__slots__ = ['_reference']
__slots__ = ('_reference', )

def __init__(self, reference=None):
"""Instantiate a transform."""
Expand All @@ -191,13 +190,11 @@ def __add__(self, b):
-------
>>> T1 = TransformBase()
>>> added = T1 + TransformBase()
>>> isinstance(added, TransformChain)
True
>>> len(added.transforms)
2
"""
from .manip import TransformChain
return TransformChain(transforms=[self, b])

@property
Expand Down Expand Up @@ -322,127 +319,6 @@ def _to_hdf5(self, x5_root):
raise NotImplementedError


class TransformChain(TransformBase):
"""Implements the concatenation of transforms."""

__slots__ = ['_transforms']

def __init__(self, transforms=None):
"""Initialize a chain of transforms."""
self._transforms = None
if transforms is not None:
self.transforms = transforms

def __add__(self, b):
"""
Compose this and other transforms.
Example
-------
>>> T1 = TransformBase()
>>> added = T1 + TransformBase() + TransformBase()
>>> isinstance(added, TransformChain)
True
>>> len(added.transforms)
3
"""
self.append(b)
return self

def __getitem__(self, i):
"""
Enable indexed access of transform chains.
Example
-------
>>> T1 = TransformBase()
>>> chain = T1 + TransformBase()
>>> chain[0] is T1
True
"""
return self.transforms[i]

def __len__(self):
"""Enable using len()."""
return len(self.transforms)

@property
def transforms(self):
"""Get the internal list of transforms."""
return self._transforms

@transforms.setter
def transforms(self, value):
self._transforms = _as_chain(value)
if self.transforms[0].reference:
self.reference = self.transforms[0].reference

def append(self, x):
"""
Concatenate one element to the chain.
Example
-------
>>> chain = TransformChain(transforms=TransformBase())
>>> chain.append((TransformBase(), TransformBase()))
>>> len(chain)
3
"""
self.transforms += _as_chain(x)

def insert(self, i, x):
"""
Insert an item at a given position.
Example
-------
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
>>> chain.insert(1, TransformBase())
>>> len(chain)
3
>>> chain.insert(1, TransformChain(chain))
>>> len(chain)
6
"""
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]

def map(self, x, inverse=False):
"""
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
Example
-------
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
>>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TransformError:
"""
if not self.transforms:
raise TransformError('Cannot apply an empty transforms chain.')

transforms = self.transforms
if inverse:
transforms = reversed(self.transforms)

for xfm in transforms:
x = xfm(x, inverse=inverse)

return x


def _as_homogeneous(xyz, dtype='float32', dim=3):
"""
Convert 2D and 3D coordinates into homogeneous coordinates.
Expand Down Expand Up @@ -473,12 +349,3 @@ def _as_homogeneous(xyz, dtype='float32', dim=3):
def _apply_affine(x, affine, dim):
"""Get the image array's indexes corresponding to coordinates."""
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T


def _as_chain(x):
"""Convert a value into a transform chain."""
if isinstance(x, TransformChain):
return x.transforms
if isinstance(x, Iterable):
return list(x)
return [x]
2 changes: 1 addition & 1 deletion nitransforms/io/itk.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def from_h5obj(cls, fileobj, check=True):
except KeyError:
typo_fallback = "Tranform"

for xfm in reversed(list(h5group.values())[1:]):
for xfm in list(h5group.values())[1:]:
if xfm["TransformType"][0].startswith(b"AffineTransform"):
_params = np.asanyarray(xfm[f"{typo_fallback}Parameters"])
xfm_list.append(
Expand Down
172 changes: 172 additions & 0 deletions nitransforms/manip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
# See COPYING file distributed along with the NiBabel package for the
# copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""
from collections.abc import Iterable

from .base import (
TransformBase,
TransformError,
)
from .linear import Affine
from .nonlinear import DisplacementsFieldTransform


class TransformChain(TransformBase):
"""Implements the concatenation of transforms."""

__slots__ = ('_transforms', )

def __init__(self, transforms=None):
"""Initialize a chain of transforms."""
super().__init__()
self._transforms = None

if transforms is not None:
self.transforms = transforms

def __add__(self, b):
"""
Compose this and other transforms.
Example
-------
>>> T1 = TransformBase()
>>> added = T1 + TransformBase() + TransformBase()
>>> isinstance(added, TransformChain)
True
>>> len(added.transforms)
3
"""
self.append(b)
return self

def __getitem__(self, i):
"""
Enable indexed access of transform chains.
Example
-------
>>> T1 = TransformBase()
>>> chain = T1 + TransformBase()
>>> chain[0] is T1
True
"""
return self.transforms[i]

def __len__(self):
"""Enable using len()."""
return len(self.transforms)

@property
def transforms(self):
"""Get the internal list of transforms."""
return self._transforms

@transforms.setter
def transforms(self, value):
self._transforms = _as_chain(value)
if self.transforms[-1].reference:
self.reference = self.transforms[-1].reference

def append(self, x):
"""
Concatenate one element to the chain.
Example
-------
>>> chain = TransformChain(transforms=TransformBase())
>>> chain.append((TransformBase(), TransformBase()))
>>> len(chain)
3
"""
self.transforms += _as_chain(x)

def insert(self, i, x):
"""
Insert an item at a given position.
Example
-------
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
>>> chain.insert(1, TransformBase())
>>> len(chain)
3
>>> chain.insert(1, TransformChain(chain))
>>> len(chain)
6
"""
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]

def map(self, x, inverse=False):
"""
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
Example
-------
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
>>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TransformError:
"""
if not self.transforms:
raise TransformError('Cannot apply an empty transforms chain.')

transforms = self.transforms
if not inverse:
transforms = self.transforms[::-1]

for xfm in transforms:
x = xfm(x, inverse=inverse)

return x

@classmethod
def from_filename(cls, filename, fmt="X5",
reference=None, moving=None):
"""Load a transform file."""
from .io import itk

retval = []
if str(filename).endswith(".h5"):
reference = None
xforms = itk.ITKCompositeH5.from_filename(filename)
for xfmobj in xforms:
if isinstance(xfmobj, itk.ITKLinearTransform):
retval.append(Affine(xfmobj.to_ras(), reference=reference))
else:
retval.append(DisplacementsFieldTransform(xfmobj))

return TransformChain(retval)

raise NotImplementedError


def _as_chain(x):
"""Convert a value into a transform chain."""
if isinstance(x, TransformChain):
return x.transforms
if isinstance(x, Iterable):
return list(x)
return [x]


load = TransformChain.from_filename
8 changes: 4 additions & 4 deletions nitransforms/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,14 @@ def test_afni_Displacements():

def test_itk_h5(data_path):
"""Test displacements fields."""
itk.ITKCompositeH5.from_filename(
assert len(list(itk.ITKCompositeH5.from_filename(
data_path / 'ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5'
)
))) == 2

with pytest.raises(RuntimeError):
itk.ITKCompositeH5.from_filename(
list(itk.ITKCompositeH5.from_filename(
data_path / 'ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.x5'
)
))


@pytest.mark.parametrize('file_type, test_file', [
Expand Down
Loading

0 comments on commit 26f9e80

Please sign in to comment.