diff --git a/odl/deform/linearized.py b/odl/deform/linearized.py index 8550a2ef13d..63cbf99ba05 100644 --- a/odl/deform/linearized.py +++ b/odl/deform/linearized.py @@ -78,7 +78,7 @@ def __init__(self, template, domain=None): Parameters ---------- - template : `DiscreteLpVector` or array-like + template : `DiscreteLpVector` or element-like Fixed template that is to be deformed. If ``domain`` is not given, ``template`` must be a `DiscreteLpVector`, and the domain of this operator @@ -92,14 +92,34 @@ def __init__(self, template, domain=None): Examples -------- + Create a template and deform it with a given deformation field. + + Where the deformation field is zero we expect to get the same output + as the input. In the 4:th point, the deformation is non-zero and hence + we expect to get the value of the point 0.2 to the left, that is 1.0. + >>> import odl - >>> space = odl.uniform_discr(0, 1, 5) - >>> disp_field_space = odl.ProductSpace(space, space.ndim) + >>> space = odl.uniform_discr(0, 1, 5, interp='nearest') + >>> template = space.element([0, 0, 1, 0, 0]) + >>> op = LinDeformFixedTempl(template) + >>> disp_field = [[0, 0, 0, -0.2, 0]] + >>> print(op(disp_field)) + [0.0, 0.0, 1.0, 1.0, 0.0] + + The result depends on the chosen interpolation. If we chose 'linear' + interpolation and offset the point half the distance between two + points, 0.1, we expect to get the mean of the values. + + >>> space = odl.uniform_discr(0, 1, 5, interp='linear') >>> template = space.element([0, 0, 1, 0, 0]) - >>> displacement_field = disp_field_space.element([[0, 0, 0, -0.2, 0]]) >>> op = LinDeformFixedTempl(template) - >>> op(displacement_field) - uniform_discr(0.0, 1.0, 5).element([0.0, 0.0, 1.0, 1.0, 0.0]) + >>> disp_field = [[0, 0, 0, -0.1, 0]] + >>> print(op(disp_field)) + [0.0, 0.0, 1.0, 0.5, 0.0] + + See Also + -------- + LinDeformFixedDisp : Deformation with a fixed displacement. """ if domain is None: if not isinstance(template, DiscreteLpVector): @@ -121,11 +141,6 @@ def __init__(self, template, domain=None): if not domain[0].is_rn: raise TypeError('`domain[0]` {!r} not a real space' ''.format(domain[0])) - if not domain[0].domain == template.space.domain: - raise TypeError('`domain[0].domain {!r} does not match ' - 'template.space.domain {!r}' - ''.format(domain[0].domain, - template.space.domain)) template = domain[0].element(template) @@ -173,6 +188,17 @@ def derivative(self, displacement): return PointwiseInner(self.domain, def_grad) + def __repr__(self): + """Return ``repr(self)``.""" + if self.domain == self._template.space.tangent_space: + domain_repr = '' + else: + domain_repr = ', domain={!r}'.format(self.domain) + + return '{}({!r}{})'.format(self.__class__.__name__, + self._template, + domain_repr) + class LinDeformFixedDisp(Operator): @@ -187,41 +213,65 @@ def __init__(self, displacement, domain=None): Parameters ---------- - displacement : `ProductSpace` element or array-like + displacement : `ProductSpace` element-like Fixed displacement field used in the linearized deformation. If ``domain`` is not given, ``displacement`` must be a `ProductSpace` element, and the domain of this operator is inferred from ``displacement[0].space``. If ``domain`` is given, ``displacement`` can be anything that is understood - by the ``ProductSpace(domain, domain.ndim).element()`` method. + by the ``domain.tangent_space.element()`` method. domain : `DiscreteLp`, optional Space of templates on which this operator acts, i.e. the operator domain. If not given, ``displacement[0].space`` is used as domain. Examples -------- + Create a given deformation and use it to deform a function. + + Where the deformation field is zero we expect to get the same output + as the input. In the 4:th point, the deformation is non-zero and hence + we expect to get the value of the point 0.2 to the left, that is 1.0. + >>> import odl >>> space = odl.uniform_discr(0, 1, 5) - >>> disp_field_space = odl.ProductSpace(space, space.ndim) - >>> displacement_field = disp_field_space.element([[0, 0, 0, -0.2, 0]]) - >>> template = space.element([0, 0, 1, 0, 0]) - >>> op = LinDeformFixedDisp(displacement_field) - >>> op(template) - uniform_discr(0.0, 1.0, 5).element([0.0, 0.0, 1.0, 1.0, 0.0]) + >>> disp_field = space.tangent_space.element([[0, 0, 0, -0.2, 0]]) + >>> op = LinDeformFixedDisp(disp_field) + >>> template = [0, 0, 1, 0, 0] + >>> print(op([0, 0, 1, 0, 0])) + [0.0, 0.0, 1.0, 1.0, 0.0] + + The result depends on the chosen interpolation. If we chose 'linear' + interpolation and offset the point half the distance between two + points, 0.1, we expect to get the mean of the values. + + >>> space = odl.uniform_discr(0, 1, 5, interp='linear') + >>> disp_field = space.tangent_space.element([[0, 0, 0, -0.1, 0]]) + >>> op = LinDeformFixedDisp(disp_field) + >>> template = [0, 0, 1, 0, 0] + >>> print(op(template)) + [0.0, 0.0, 1.0, 0.5, 0.0] + + See Also + -------- + LinDeformFixedTempl : Deformation with a fixed template. """ if domain is None: - if not isinstance(displacement.space[0], DiscreteLp): - raise TypeError('`displacement[0]` {!r} not an element of' - '`DiscreteLp`'.format(displacement[0])) - if not displacement.space[0].is_rn: - raise TypeError('`displacement[0]` {!r} not a real space' - ''.format(displacement[0])) + if not isinstance(displacement.space, ProductSpace): + raise TypeError('`displacement.space` {!r} not a ' + '`ProductSpace`'.format(displacement.space)) if not displacement.space.is_power_space: - raise TypeError('`displacement.space` {!r} not a product' + raise TypeError('`displacement.space` {!r} not a power' 'space'.format(displacement.space)) + if not isinstance(displacement[0].space, DiscreteLp): + raise TypeError('`displacement[0].space` {!r} not an ' + '`DiscreteLp`'.format(displacement[0])) domain = displacement[0].space else: + if not isinstance(domain, DiscreteLp): + raise TypeError('`displacement[0]` {!r} not an `DiscreteLp`' + ''.format(displacement[0])) + displacement = domain.tangent_space.element(displacement) Operator.__init__(self, domain, domain, linear=True) @@ -257,6 +307,17 @@ def adjoint(self): domain=self.domain) return jacobian_det * deformation + def __repr__(self): + """Return ``repr(self)``.""" + if self.domain == self._displacement.space[0]: + domain_repr = '' + else: + domain_repr = ', domain={!r}'.format(self.domain) + + return '{}({!r}{})'.format(self.__class__.__name__, + self._displacement, + domain_repr) + if __name__ == '__main__': # pylint: disable=wrong-import-position diff --git a/test/deform/linearized_deform_test.py b/test/deform/linearized_deform_test.py index 41e0dc839aa..771df3f1a19 100644 --- a/test/deform/linearized_deform_test.py +++ b/test/deform/linearized_deform_test.py @@ -25,6 +25,7 @@ import pytest import numpy as np import odl +from odl.deform import LinDeformFixedTempl, LinDeformFixedDisp # Set up fixtures @@ -172,12 +173,40 @@ def inv_deform_template(x): return template_function(disp_x) -def test_fixed_templ(space): +# Test implementations start here + + +def test_fixed_templ_init(): + """Verify that the init method and checks work properly.""" + space = odl.uniform_discr(0, 1, 5) + template = space.element(template_function) + + # Valid input + print(LinDeformFixedTempl(template, space.tangent_space)) + print(LinDeformFixedTempl(template_function, domain=space.tangent_space)) + print(LinDeformFixedTempl(template, domain=space.tangent_space)) + print(LinDeformFixedTempl(template=template, domain=space.tangent_space)) + + # Non-valid input + with pytest.raises(TypeError): # domain not product space + LinDeformFixedTempl(template, space) + with pytest.raises(TypeError): # domain wrong type of product space + bad_pspace = odl.ProductSpace(space, odl.rn(3)) + LinDeformFixedTempl(template, bad_pspace) + with pytest.raises(TypeError): # domain product space of non DiscreteLp + bad_pspace = odl.ProductSpace(odl.rn(2), 1) + LinDeformFixedTempl(template, bad_pspace) + with pytest.raises(TypeError): # wrong dtype on domain + wrong_dtype = odl.ProductSpace(space.astype(complex), 1) + LinDeformFixedTempl(template, wrong_dtype) + + +def test_fixed_templ_call(space): """Test deformation for LinDeformFixedTempl.""" # Define the analytic template as the hat function and its gradient template = space.element(template_function) - fixed_templ_op = odl.deform.LinDeformFixedTempl(template) + fixed_templ_op = LinDeformFixedTempl(template) # Calculate result and exact result deform_templ_exact = space.element(deform_template) @@ -197,7 +226,7 @@ def test_fixed_templ_deriv(space): template = space.element(template_function) disp_field = disp_field_factory(space.ndim) vector_field = vector_field_factory(space.ndim) - fixed_templ_op = odl.deform.LinDeformFixedTempl(template) + fixed_templ_op = LinDeformFixedTempl(template) # Calculate result fixed_templ_op_deriv = fixed_templ_op.derivative(disp_field) @@ -212,13 +241,38 @@ def test_fixed_templ_deriv(space): assert rlt_err < error_bound(space.interp) -def test_fixed_disp(space): +def test_fixed_disp_init(): + """Verify that the init method and checks work properly.""" + space = odl.uniform_discr(0, 1, 5) + disp_field = space.tangent_space.element(disp_field_factory(space.ndim)) + + # Valid input + print(LinDeformFixedDisp(disp_field, space)) + print(LinDeformFixedDisp(disp_field, domain=space)) + print(LinDeformFixedDisp(disp_field_factory(space.ndim), domain=space)) + print(LinDeformFixedDisp(displacement=disp_field, domain=space)) + + # Non-valid input + with pytest.raises(TypeError): # domain not DiscreteLp + LinDeformFixedDisp(disp_field, space.tangent_space) + with pytest.raises(TypeError): # domain wrong type of product space + bad_pspace = odl.ProductSpace(space, odl.rn(3)) + LinDeformFixedDisp(disp_field, bad_pspace) + with pytest.raises(TypeError): # domain product space of non DiscreteLp + bad_pspace = odl.ProductSpace(odl.rn(2), 1) + LinDeformFixedDisp(disp_field, bad_pspace) + with pytest.raises(TypeError): # wrong dtype on domain + wrong_dtype = odl.ProductSpace(space.astype(complex), 1) + LinDeformFixedDisp(disp_field, wrong_dtype) + + +def test_fixed_disp_call(space): """Verify that LinDeformFixedDisp produces the correct deformation.""" template = space.element(template_function) disp_field = space.tangent_space.element(disp_field_factory(space.ndim)) # Calculate result and exact result - fixed_disp_op = odl.deform.LinDeformFixedDisp(disp_field, domain=space) + fixed_disp_op = LinDeformFixedDisp(disp_field, domain=space) deform_templ_comp = fixed_disp_op(template) deform_templ_exact = space.element(deform_template) @@ -235,7 +289,7 @@ def test_fixed_disp_adj(space): disp_field = space.tangent_space.element(disp_field_factory(space.ndim)) # Calculate result - fixed_disp_op = odl.deform.LinDeformFixedDisp(disp_field, domain=space) + fixed_disp_op = LinDeformFixedDisp(disp_field, domain=space) fixed_disp_adj_comp = fixed_disp_op.adjoint(template) # Calculate the analytic result