diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 944100d5..b6e3c77e 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -6,7 +6,7 @@ from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.split_idx import split_idx -from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view +from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted import mrpro.utils.unit_conversion __all__ = [ @@ -14,6 +14,7 @@ "fill_range_", "reduce_view", "remove_repeat", + "reshape_broadcasted", "slice_profiles", "smap", "split_idx", diff --git a/src/mrpro/utils/reshape.py b/src/mrpro/utils/reshape.py index 435c35b7..c89a12ef 100644 --- a/src/mrpro/utils/reshape.py +++ b/src/mrpro/utils/reshape.py @@ -1,7 +1,6 @@ """Tensor reshaping utilities.""" from collections.abc import Sequence -from functools import lru_cache import torch @@ -102,39 +101,38 @@ def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torc return torch.as_strided(x, newsize, stride) -@lru_cache +# @lru_cache def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_stride: tuple[int, ...]) -> list[slice]: - """Get reshape reduce index (Cached helper function for reshape_view).""" - # This function tries to group axes from new_shape and old_shape into the smallest groups that have# + """Get reshape reduce index (Cached helper function for reshape_broadcasted).""" + # This function tries to group axes from new_shape and old_shape into the smallest groups that have # the same number of elements, starting from the right. # If all axes of old shape of a group are stride=0 dimensions, # we can reduce them. idx = [] - i, j = len(old_shape), len(new_shape) - while i and j: - product_new = product_old = 1 - grouped = [] - while product_old != product_new or not grouped: - if product_old < product_new: - i -= 1 - grouped.append(i) - product_old *= old_shape[i] + pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1 # start from the right + while pointer_old >= 0: + product_new, product_old = 1, 1 + group: list[int] = [] + while product_old != product_new or not group: + if product_old <= product_new: + product_old *= old_shape[pointer_old] + group.append(pointer_old) + pointer_old -= 1 else: - j -= 1 - product_new *= new_shape[j] + product_new *= new_shape[pointer_new] + pointer_new -= 1 # we found a group - if all(old_stride[d] == 0 for d in grouped): + if all(old_stride[d] == 0 for d in group): # all dimensions are broadcasted # reduce to singleton - idx.extend([slice(1)] * len(grouped)) + idx.extend([slice(1)] * len(group)) else: # preserve - idx.extend([slice(None)] * len(grouped)) - + idx.extend([slice(None)] * len(group)) return idx[::-1] -def reshape(tensor: torch.Tensor, *shape: int) -> torch.Tensor: +def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor: """Reshape a tensor while preserving broadcasted (stride 0) dimensions where possible. Parameters @@ -150,9 +148,13 @@ def reshape(tensor: torch.Tensor, *shape: int) -> torch.Tensor: """ try: + # if we can view the tensor directly, it will preserve broadcasting return tensor.view(shape) except RuntimeError: + if tensor.shape.numel() != torch.Size(shape).numel(): + raise ValueError('Cannot reshape tensor to target shape, number of elements must match') from None idx = _reshape_idx(tensor.shape, shape, tensor.stride()) - # make contiguous in all dimensions in which broadcasting cannot be preserved - semicontiguous = tensor[idx].contiguous().expand(tensor.shape) + # make contiguous only in dimensions in which broadcasting cannot be preserved + semicontiguous = tensor[idx].contiguous() + semicontiguous = semicontiguous.expand(tensor.shape) return semicontiguous.view(shape) diff --git a/tests/utils/test_reshape.py b/tests/utils/test_reshape.py index dd57b8fe..457b2725 100644 --- a/tests/utils/test_reshape.py +++ b/tests/utils/test_reshape.py @@ -1,7 +1,8 @@ """Tests for reshaping utilities.""" +import pytest import torch -from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right +from mrpro.utils import broadcast_right, reduce_view, reshape_broadcasted, unsqueeze_left, unsqueeze_right from tests import RandomGenerator @@ -51,3 +52,30 @@ def test_reduce_view(): reduced_one_pos = reduce_view(tensor, 0) assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6) assert torch.equal(reduced_one_pos.expand_as(tensor), tensor) + + +@pytest.mark.parametrize( + ('shape', 'expand_shape', 'permute', 'final_shape', 'expected_stride'), + [ + ((1, 2, 3, 1, 1), (1, 2, 3, 4, 5), (0, 2, 1, 3, 4), (1, 6, 2, 2, 5), (6, 1, 0, 0, 0)), + ((1, 2, 1), (100, 2, 2), (0, 1, 2), (100, 4), (0, 1)), + ((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 0, 1), (1, 2, 6, 10, 1), (0, 0, 0, 0, 0)), + ((1, 2, 3), (1, 2, 3), (0, 1, 2), (6,), (1,)), + ], +) +def test_reshape_broadcasted(shape, expand_shape, permute, final_shape, expected_stride): + """Test reshape_broadcasted""" + rng = RandomGenerator(0) + tensor = rng.float32_tensor(shape).expand(*expand_shape).permute(*permute) + reshaped = reshape_broadcasted(tensor, *final_shape) + expected_values = tensor.reshape(*final_shape) + assert reshaped.shape == expected_values.shape + assert reshaped.stride() == expected_stride + assert torch.equal(reshaped, expected_values) + + +def test_reshape_broadcasted_fail(): + """Test reshape_broadcasted with invalid input""" + a = torch.ones(2) + with pytest.raises(ValueError, match='number of elements must match'): + reshape_broadcasted(a, 3)