Commit
add tests
fzimmermann89 committed Nov 22, 2024
1 parent 64ae029 commit a96b336
Showing 3 changed files with 55 additions and 24 deletions.
src/mrpro/utils/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -6,14 +6,15 @@
 from mrpro.utils.remove_repeat import remove_repeat
 from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
 from mrpro.utils.split_idx import split_idx
-from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view
+from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted
 import mrpro.utils.unit_conversion
 
 __all__ = [
     "broadcast_right",
     "fill_range_",
     "reduce_view",
     "remove_repeat",
+    "reshape_broadcasted",
     "slice_profiles",
     "smap",
     "split_idx",
src/mrpro/utils/reshape.py (46 changes: 24 additions & 22 deletions)
@@ -1,7 +1,6 @@
"""Tensor reshaping utilities."""

from collections.abc import Sequence
from functools import lru_cache

import torch

@@ -102,39 +101,38 @@ def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor:
     return torch.as_strided(x, newsize, stride)
 
 
-@lru_cache
+# @lru_cache
 def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_stride: tuple[int, ...]) -> list[slice]:
-    """Get reshape reduce index (Cached helper function for reshape_view)."""
-    # This function tries to group axes from new_shape and old_shape into the smallest groups that have#
+    """Get reshape reduce index (Cached helper function for reshape_broadcasted)."""
+    # This function tries to group axes from new_shape and old_shape into the smallest groups that have
     # the same number of elements, starting from the right.
     # If all axes of old shape of a group are stride=0 dimensions,
     # we can reduce them.
     idx = []
-    i, j = len(old_shape), len(new_shape)
-    while i and j:
-        product_new = product_old = 1
-        grouped = []
-        while product_old != product_new or not grouped:
-            if product_old < product_new:
-                i -= 1
-                grouped.append(i)
-                product_old *= old_shape[i]
+    pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1  # start from the right
+    while pointer_old >= 0:
+        product_new, product_old = 1, 1
+        group: list[int] = []
+        while product_old != product_new or not group:
+            if product_old <= product_new:
+                product_old *= old_shape[pointer_old]
+                group.append(pointer_old)
+                pointer_old -= 1
             else:
-                j -= 1
-                product_new *= new_shape[j]
+                product_new *= new_shape[pointer_new]
+                pointer_new -= 1
         # we found a group
-        if all(old_stride[d] == 0 for d in grouped):
+        if all(old_stride[d] == 0 for d in group):
             # all dimensions are broadcasted
             # reduce to singleton
-            idx.extend([slice(1)] * len(grouped))
+            idx.extend([slice(1)] * len(group))
         else:
             # preserve
-            idx.extend([slice(None)] * len(grouped))
-
+            idx.extend([slice(None)] * len(group))
     return idx[::-1]
 
 
-def reshape(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
+def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
     """Reshape a tensor while preserving broadcasted (stride 0) dimensions where possible.
 
     Parameters
@@ -150,9 +148,13 @@ def reshape(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
"""
try:
# if we can view the tensor directly, it will preserve broadcasting
return tensor.view(shape)
except RuntimeError:
if tensor.shape.numel() != torch.Size(shape).numel():
raise ValueError('Cannot reshape tensor to target shape, number of elements must match') from None
idx = _reshape_idx(tensor.shape, shape, tensor.stride())
# make contiguous in all dimensions in which broadcasting cannot be preserved
semicontiguous = tensor[idx].contiguous().expand(tensor.shape)
# make contiguous only in dimensions in which broadcasting cannot be preserved
semicontiguous = tensor[idx].contiguous()
semicontiguous = semicontiguous.expand(tensor.shape)
return semicontiguous.view(shape)
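
As a usage sketch of the new function (assuming only the mrpro.utils export added in this commit; the stride values in the comments follow from the fallback path above): when a direct view is impossible, plain reshape copies the fully materialized tensor, while reshape_broadcasted keeps fully broadcast axis groups at stride 0.

import torch

from mrpro.utils import reshape_broadcasted

# 2 stored elements broadcast to 400; strides are (0, 1, 0)
x = torch.randn(1, 2, 1).expand(100, 2, 2)

# a direct .view(100, 4) fails for these strides, so the fallback is taken
y = reshape_broadcasted(x, 100, 4)
print(y.stride())  # (0, 1): the broadcast leading dimension survives

# plain reshape also cannot view here and instead copies all 400 elements
z = x.reshape(100, 4)
print(z.stride())  # (4, 1)
assert torch.equal(y, z)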
tests/utils/test_reshape.py (30 changes: 29 additions & 1 deletion)
@@ -1,7 +1,8 @@
"""Tests for reshaping utilities."""

import pytest
import torch
from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right
from mrpro.utils import broadcast_right, reduce_view, reshape_broadcasted, unsqueeze_left, unsqueeze_right

from tests import RandomGenerator

@@ -51,3 +52,30 @@ def test_reduce_view():
     reduced_one_pos = reduce_view(tensor, 0)
     assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6)
     assert torch.equal(reduced_one_pos.expand_as(tensor), tensor)
+
+
+@pytest.mark.parametrize(
+    ('shape', 'expand_shape', 'permute', 'final_shape', 'expected_stride'),
+    [
+        ((1, 2, 3, 1, 1), (1, 2, 3, 4, 5), (0, 2, 1, 3, 4), (1, 6, 2, 2, 5), (6, 1, 0, 0, 0)),
+        ((1, 2, 1), (100, 2, 2), (0, 1, 2), (100, 4), (0, 1)),
+        ((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 0, 1), (1, 2, 6, 10, 1), (0, 0, 0, 0, 0)),
+        ((1, 2, 3), (1, 2, 3), (0, 1, 2), (6,), (1,)),
+    ],
+)
+def test_reshape_broadcasted(shape, expand_shape, permute, final_shape, expected_stride):
+    """Test reshape_broadcasted"""
+    rng = RandomGenerator(0)
+    tensor = rng.float32_tensor(shape).expand(*expand_shape).permute(*permute)
+    reshaped = reshape_broadcasted(tensor, *final_shape)
+    expected_values = tensor.reshape(*final_shape)
+    assert reshaped.shape == expected_values.shape
+    assert reshaped.stride() == expected_stride
+    assert torch.equal(reshaped, expected_values)
+
+
+def test_reshape_broadcasted_fail():
+    """Test reshape_broadcasted with invalid input"""
+    a = torch.ones(2)
+    with pytest.raises(ValueError, match='number of elements must match'):
+        reshape_broadcasted(a, 3)
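
The second parametrized case above also makes a handy worked example for the grouping helper; a minimal sketch, assuming the private _reshape_idx can be imported directly from mrpro.utils.reshape:

from mrpro.utils.reshape import _reshape_idx

# old shape (100, 2, 2) with strides (0, 1, 0), reshaped to (100, 4):
# grouping from the right matches axes {1, 2} to the new axis of size 4
# (2 * 2 == 4) and axis {0} to the new axis of size 100. the {1, 2} group
# contains a non-zero stride, so both axes are preserved; the {0} group is
# entirely stride 0, so it is cut to a singleton (reshape_broadcasted
# re-expands it afterwards).
idx = _reshape_idx((100, 2, 2), (100, 4), (0, 1, 0))
print(idx)  # [slice(None, 1, None), slice(None, None, None), slice(None, None, None)]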
