Skip to content

Commit

Permalink
DOC: improve doc of linearized deformations
Browse files Browse the repository at this point in the history
  • Loading branch information
adler-j committed Jul 11, 2016
1 parent 7581170 commit ee2b210
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 32 deletions.
113 changes: 87 additions & 26 deletions odl/deform/linearized.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, template, domain=None):
Parameters
----------
template : `DiscreteLpVector` or array-like
template : `DiscreteLpVector` or element-like
Fixed template that is to be deformed.
If ``domain`` is not given, ``template`` must
be a `DiscreteLpVector`, and the domain of this operator
Expand All @@ -92,14 +92,34 @@ def __init__(self, template, domain=None):
Examples
--------
Create a template and deform it with a given deformation field.
Where the deformation field is zero we expect to get the same output
as the input. In the 4:th point, the deformation is non-zero and hence
we expect to get the value of the point 0.2 to the left, that is 1.0.
>>> import odl
>>> space = odl.uniform_discr(0, 1, 5)
>>> disp_field_space = odl.ProductSpace(space, space.ndim)
>>> space = odl.uniform_discr(0, 1, 5, interp='nearest')
>>> template = space.element([0, 0, 1, 0, 0])
>>> op = LinDeformFixedTempl(template)
>>> disp_field = [[0, 0, 0, -0.2, 0]]
>>> print(op(disp_field))
[0.0, 0.0, 1.0, 1.0, 0.0]
The result depends on the chosen interpolation. If we chose 'linear'
interpolation and offset the point half the distance between two
points, 0.1, we expect to get the mean of the values.
>>> space = odl.uniform_discr(0, 1, 5, interp='linear')
>>> template = space.element([0, 0, 1, 0, 0])
>>> displacement_field = disp_field_space.element([[0, 0, 0, -0.2, 0]])
>>> op = LinDeformFixedTempl(template)
>>> op(displacement_field)
uniform_discr(0.0, 1.0, 5).element([0.0, 0.0, 1.0, 1.0, 0.0])
>>> disp_field = [[0, 0, 0, -0.1, 0]]
>>> print(op(disp_field))
[0.0, 0.0, 1.0, 0.5, 0.0]
See Also
--------
LinDeformFixedDisp : Deformation with a fixed displacement.
"""
if domain is None:
if not isinstance(template, DiscreteLpVector):
Expand All @@ -121,11 +141,6 @@ def __init__(self, template, domain=None):
if not domain[0].is_rn:
raise TypeError('`domain[0]` {!r} not a real space'
''.format(domain[0]))
if not domain[0].domain == template.space.domain:
raise TypeError('`domain[0].domain {!r} does not match '
'template.space.domain {!r}'
''.format(domain[0].domain,
template.space.domain))

template = domain[0].element(template)

Expand Down Expand Up @@ -173,6 +188,17 @@ def derivative(self, displacement):

return PointwiseInner(self.domain, def_grad)

def __repr__(self):
"""Return ``repr(self)``."""
if self.domain == self._template.space.tangent_space:
domain_repr = ''
else:
domain_repr = ', domain={!r}'.format(self.domain)

return '{}({!r}{})'.format(self.__class__.__name__,
self._template,
domain_repr)


class LinDeformFixedDisp(Operator):

Expand All @@ -187,41 +213,65 @@ def __init__(self, displacement, domain=None):
Parameters
----------
displacement : `ProductSpace` element or array-like
displacement : `ProductSpace` element-like
Fixed displacement field used in the linearized deformation.
If ``domain`` is not given, ``displacement`` must
be a `ProductSpace` element, and the domain of this operator
is inferred from ``displacement[0].space``. If ``domain`` is
given, ``displacement`` can be anything that is understood
by the ``ProductSpace(domain, domain.ndim).element()`` method.
by the ``domain.tangent_space.element()`` method.
domain : `DiscreteLp`, optional
Space of templates on which this operator acts, i.e. the operator
domain. If not given, ``displacement[0].space`` is used as domain.
Examples
--------
Create a given deformation and use it to deform a function.
Where the deformation field is zero we expect to get the same output
as the input. In the 4:th point, the deformation is non-zero and hence
we expect to get the value of the point 0.2 to the left, that is 1.0.
>>> import odl
>>> space = odl.uniform_discr(0, 1, 5)
>>> disp_field_space = odl.ProductSpace(space, space.ndim)
>>> displacement_field = disp_field_space.element([[0, 0, 0, -0.2, 0]])
>>> template = space.element([0, 0, 1, 0, 0])
>>> op = LinDeformFixedDisp(displacement_field)
>>> op(template)
uniform_discr(0.0, 1.0, 5).element([0.0, 0.0, 1.0, 1.0, 0.0])
>>> disp_field = space.tangent_space.element([[0, 0, 0, -0.2, 0]])
>>> op = LinDeformFixedDisp(disp_field)
>>> template = [0, 0, 1, 0, 0]
>>> print(op([0, 0, 1, 0, 0]))
[0.0, 0.0, 1.0, 1.0, 0.0]
The result depends on the chosen interpolation. If we chose 'linear'
interpolation and offset the point half the distance between two
points, 0.1, we expect to get the mean of the values.
>>> space = odl.uniform_discr(0, 1, 5, interp='linear')
>>> disp_field = space.tangent_space.element([[0, 0, 0, -0.1, 0]])
>>> op = LinDeformFixedDisp(disp_field)
>>> template = [0, 0, 1, 0, 0]
>>> print(op(template))
[0.0, 0.0, 1.0, 0.5, 0.0]
See Also
--------
LinDeformFixedTempl : Deformation with a fixed template.
"""
if domain is None:
if not isinstance(displacement.space[0], DiscreteLp):
raise TypeError('`displacement[0]` {!r} not an element of'
'`DiscreteLp`'.format(displacement[0]))
if not displacement.space[0].is_rn:
raise TypeError('`displacement[0]` {!r} not a real space'
''.format(displacement[0]))
if not isinstance(displacement.space, ProductSpace):
raise TypeError('`displacement.space` {!r} not a '
'`ProductSpace`'.format(displacement.space))
if not displacement.space.is_power_space:
raise TypeError('`displacement.space` {!r} not a product'
raise TypeError('`displacement.space` {!r} not a power'
'space'.format(displacement.space))
if not isinstance(displacement[0].space, DiscreteLp):
raise TypeError('`displacement[0].space` {!r} not an '
'`DiscreteLp`'.format(displacement[0]))

domain = displacement[0].space
else:
if not isinstance(domain, DiscreteLp):
raise TypeError('`displacement[0]` {!r} not an `DiscreteLp`'
''.format(displacement[0]))

displacement = domain.tangent_space.element(displacement)

Operator.__init__(self, domain, domain, linear=True)
Expand Down Expand Up @@ -257,6 +307,17 @@ def adjoint(self):
domain=self.domain)
return jacobian_det * deformation

def __repr__(self):
"""Return ``repr(self)``."""
if self.domain == self._displacement.space[0]:
domain_repr = ''
else:
domain_repr = ', domain={!r}'.format(self.domain)

return '{}({!r}{})'.format(self.__class__.__name__,
self._displacement,
domain_repr)


if __name__ == '__main__':
# pylint: disable=wrong-import-position
Expand Down
66 changes: 60 additions & 6 deletions test/deform/linearized_deform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest
import numpy as np
import odl
from odl.deform import LinDeformFixedTempl, LinDeformFixedDisp


# Set up fixtures
Expand Down Expand Up @@ -172,12 +173,40 @@ def inv_deform_template(x):
return template_function(disp_x)


def test_fixed_templ(space):
# Test implementations start here


def test_fixed_templ_init():
"""Verify that the init method and checks work properly."""
space = odl.uniform_discr(0, 1, 5)
template = space.element(template_function)

# Valid input
print(LinDeformFixedTempl(template, space.tangent_space))
print(LinDeformFixedTempl(template_function, domain=space.tangent_space))
print(LinDeformFixedTempl(template, domain=space.tangent_space))
print(LinDeformFixedTempl(template=template, domain=space.tangent_space))

# Non-valid input
with pytest.raises(TypeError): # domain not product space
LinDeformFixedTempl(template, space)
with pytest.raises(TypeError): # domain wrong type of product space
bad_pspace = odl.ProductSpace(space, odl.rn(3))
LinDeformFixedTempl(template, bad_pspace)
with pytest.raises(TypeError): # domain product space of non DiscreteLp
bad_pspace = odl.ProductSpace(odl.rn(2), 1)
LinDeformFixedTempl(template, bad_pspace)
with pytest.raises(TypeError): # wrong dtype on domain
wrong_dtype = odl.ProductSpace(space.astype(complex), 1)
LinDeformFixedTempl(template, wrong_dtype)


def test_fixed_templ_call(space):
"""Test deformation for LinDeformFixedTempl."""

# Define the analytic template as the hat function and its gradient
template = space.element(template_function)
fixed_templ_op = odl.deform.LinDeformFixedTempl(template)
fixed_templ_op = LinDeformFixedTempl(template)

# Calculate result and exact result
deform_templ_exact = space.element(deform_template)
Expand All @@ -197,7 +226,7 @@ def test_fixed_templ_deriv(space):
template = space.element(template_function)
disp_field = disp_field_factory(space.ndim)
vector_field = vector_field_factory(space.ndim)
fixed_templ_op = odl.deform.LinDeformFixedTempl(template)
fixed_templ_op = LinDeformFixedTempl(template)

# Calculate result
fixed_templ_op_deriv = fixed_templ_op.derivative(disp_field)
Expand All @@ -212,13 +241,38 @@ def test_fixed_templ_deriv(space):
assert rlt_err < error_bound(space.interp)


def test_fixed_disp(space):
def test_fixed_disp_init():
"""Verify that the init method and checks work properly."""
space = odl.uniform_discr(0, 1, 5)
disp_field = space.tangent_space.element(disp_field_factory(space.ndim))

# Valid input
print(LinDeformFixedDisp(disp_field, space))
print(LinDeformFixedDisp(disp_field, domain=space))
print(LinDeformFixedDisp(disp_field_factory(space.ndim), domain=space))
print(LinDeformFixedDisp(displacement=disp_field, domain=space))

# Non-valid input
with pytest.raises(TypeError): # domain not DiscreteLp
LinDeformFixedDisp(disp_field, space.tangent_space)
with pytest.raises(TypeError): # domain wrong type of product space
bad_pspace = odl.ProductSpace(space, odl.rn(3))
LinDeformFixedDisp(disp_field, bad_pspace)
with pytest.raises(TypeError): # domain product space of non DiscreteLp
bad_pspace = odl.ProductSpace(odl.rn(2), 1)
LinDeformFixedDisp(disp_field, bad_pspace)
with pytest.raises(TypeError): # wrong dtype on domain
wrong_dtype = odl.ProductSpace(space.astype(complex), 1)
LinDeformFixedDisp(disp_field, wrong_dtype)


def test_fixed_disp_call(space):
"""Verify that LinDeformFixedDisp produces the correct deformation."""
template = space.element(template_function)
disp_field = space.tangent_space.element(disp_field_factory(space.ndim))

# Calculate result and exact result
fixed_disp_op = odl.deform.LinDeformFixedDisp(disp_field, domain=space)
fixed_disp_op = LinDeformFixedDisp(disp_field, domain=space)
deform_templ_comp = fixed_disp_op(template)
deform_templ_exact = space.element(deform_template)

Expand All @@ -235,7 +289,7 @@ def test_fixed_disp_adj(space):
disp_field = space.tangent_space.element(disp_field_factory(space.ndim))

# Calculate result
fixed_disp_op = odl.deform.LinDeformFixedDisp(disp_field, domain=space)
fixed_disp_op = LinDeformFixedDisp(disp_field, domain=space)
fixed_disp_adj_comp = fixed_disp_op.adjoint(template)

# Calculate the analytic result
Expand Down

0 comments on commit ee2b210

Please sign in to comment.