
Commit

add rearrange op
ckolbPTB committed Nov 18, 2024
1 parent 7c5273c commit 0c615dc
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions tests/operators/test_rearrangeop.py
@@ -3,11 +3,14 @@
import pytest
from mrpro.operators.RearrangeOp import RearrangeOp

from tests import RandomGenerator, dotproduct_adjointness_test

from tests import (
    RandomGenerator,
    dotproduct_adjointness_test,
    forward_mode_autodiff_of_linear_operator_test,
    gradient_of_linear_operator_test,
)

@pytest.mark.parametrize('dtype', ['float32', 'complex128'])
@pytest.mark.parametrize(
SHAPE_PARAMETERS = pytest.mark.parametrize(
    ('input_shape', 'rule', 'output_shape', 'additional_info'),
    [
        ((1, 2, 3), 'a b c-> b a c', (2, 1, 3), None),  # swap axes
@@ -16,8 +19,12 @@
    ],
    ids=['swap_axes', 'flatten', 'unflatten'],
)
def test_einsum_op(input_shape, rule, output_shape, additional_info, dtype):
"""Test adjointness and shape."""


@pytest.mark.parametrize('dtype', ['float32', 'complex128'])
@SHAPE_PARAMETERS
def test_einsum_op_adjointness(input_shape, rule, output_shape, additional_info, dtype):
"""Test adjointness and shape of Einsum Op."""
generator = RandomGenerator(seed=0)
generate_tensor = getattr(generator, f'{dtype}_tensor')
u = generate_tensor(size=input_shape)
@@ -26,6 +33,30 @@ def test_einsum_op(input_shape, rule, output_shape, additional_info, dtype):
    dotproduct_adjointness_test(operator, u, v)


@pytest.mark.parametrize('dtype', ['float32', 'complex128'])
@SHAPE_PARAMETERS
def test_einsum_op_grad(input_shape, rule, output_shape, additional_info, dtype):
"""Test gradient of Einsum Op."""
generator = RandomGenerator(seed=0)
generate_tensor = getattr(generator, f'{dtype}_tensor')
u = generate_tensor(size=input_shape)
v = generate_tensor(size=output_shape)
operator = RearrangeOp(rule, additional_info)
gradient_of_linear_operator_test(operator, u, v)


@pytest.mark.parametrize('dtype', ['float32', 'complex128'])
@SHAPE_PARAMETERS
def test_einsum_op_forward_mode_autodiff(input_shape, rule, output_shape, additional_info, dtype):
"""Test forward-mode autodiff of Einsum Op."""
generator = RandomGenerator(seed=0)
generate_tensor = getattr(generator, f'{dtype}_tensor')
u = generate_tensor(size=input_shape)
v = generate_tensor(size=output_shape)
operator = RearrangeOp(rule, additional_info)
forward_mode_autodiff_of_linear_operator_test(operator, u, v)


def test_einsum_op_invalid():
"""Test with invalid rule."""
with pytest.raises(ValueError, match='pattern should match'):
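
Note on the SHAPE_PARAMETERS refactor above: pytest.mark.parametrize returns a reusable decorator object, so one parameter set can be applied to all three tests instead of repeating the argument list. A minimal self-contained sketch of this pattern follows; the names CASES, test_square, and test_square_is_nonnegative are illustrative and not part of the mrpro test suite.

import pytest

# One shared parametrization, analogous to SHAPE_PARAMETERS in the diff above.
CASES = pytest.mark.parametrize(('x', 'expected'), [(2, 4), (3, 9)], ids=['two', 'three'])


@CASES
def test_square(x, expected):
    assert x**2 == expected


@CASES
def test_square_is_nonnegative(x, expected):
    assert x**2 >= 0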

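The helpers dotproduct_adjointness_test, gradient_of_linear_operator_test, and forward_mode_autodiff_of_linear_operator_test are imported from the tests package and are not shown in this diff. For reference, the dot-product adjointness property such a helper presumably checks is <A u, v> == <u, A^H v>; below is a minimal standalone sketch of that check with a plain complex matrix in place of an mrpro operator (an assumption for illustration, not mrpro code).

import torch

# Standalone illustration: verify <A u, v> == <u, A^H v> for a linear map A,
# here a random complex matrix; mrpro's actual test helper may differ.
generator = torch.Generator().manual_seed(0)
a = torch.randn(4, 3, dtype=torch.complex64, generator=generator)
u = torch.randn(3, dtype=torch.complex64, generator=generator)
v = torch.randn(4, dtype=torch.complex64, generator=generator)
lhs = torch.vdot(a @ u, v)  # <A u, v>; vdot conjugates its first argument
rhs = torch.vdot(u, a.conj().T @ v)  # <u, A^H v>
torch.testing.assert_close(lhs, rhs)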