
Commit

add more comments
fzimmermann89 authored Dec 2, 2024
1 parent a96b336 commit 6c97ee5
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions src/mrpro/utils/reshape.py
@@ -101,35 +101,52 @@ def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor:
    return torch.as_strided(x, newsize, stride)


# @lru_cache
def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_stride: tuple[int, ...]) -> list[slice]:
"""Get reshape reduce index (Cached helper function for reshape_broadcasted)."""
# This function tries to group axes from new_shape and old_shape into the smallest groups that have
# the same number of elements, starting from the right.
# If all axes of old shape of a group are stride=0 dimensions,
# we can reduce them.
"""Get reshape reduce index (Cached helper function for reshape_broadcasted)
This function tries to group axes from new_shape and old_shape into the smallest groups that have
the same number of elements, starting from the right.
If all axes of old shape of a group are stride=0 dimensions, we can reduce them.
Example:
old_shape = (30, 2, 2, 3)
new_shape = (6. 5, 4, 3)
Will results in the groups (starting from the right):
- old: 3 new: 3
- old: 2, 2 new: 4
- old: 30 new: 6. 5
Only the "old" groups are important.
If all axes that are grouped together in an "old" group are stride 0 (=broadcasted)
we can collapse them to singleton dimensions.
This function returns the indexer that either collapses dimensions to singleton or keeps all
elements, i.e. the slices in the returned list are all either slice(1) or slice(None).
"""
    idx = []
    pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1  # start from the right
    while pointer_old >= 0:
        product_new, product_old = 1, 1
        product_new, product_old = 1, 1  # the number of elements in the current "new" and "old" group
        group: list[int] = []
        while product_old != product_new or not group:
            if product_old <= product_new:
                # increase "old" group
                product_old *= old_shape[pointer_old]
                group.append(pointer_old)
                pointer_old -= 1
            else:
                # increase "new" group
                # we don't need to track the members of the "new" group, only the number of elements covered
                product_new *= new_shape[pointer_new]
                pointer_new -= 1
        # we found a group
        # we found a group. now we need to decide what to do.
        if all(old_stride[d] == 0 for d in group):
            # all dimensions are broadcasted
            # reduce to singleton
            # -> reduce to singleton
            idx.extend([slice(1)] * len(group))
        else:
            # preserve
            # preserve dimension
            idx.extend([slice(None)] * len(group))
    return idx[::-1]
    idx = idx[::-1]  # we worked right to left, but our index should be left to right
    return idx
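

# A minimal illustration (not part of the commit) of the grouping described above, assuming a tensor
# of shape (30, 2, 2, 3) whose two middle dimensions were broadcast, i.e. a hypothetical stride of (3, 0, 0, 1):
#
#     _reshape_idx(old_shape=(30, 2, 2, 3), new_shape=(6, 5, 4, 3), old_stride=(3, 0, 0, 1))
#     # -> [slice(None, None, None), slice(None, 1, None), slice(None, 1, None), slice(None, None, None)]
#
# The two broadcast dimensions form their own "old" group and are collapsed to singletons (slice(1)),
# while the first and last dimensions are kept in full (slice(None)).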


def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
@@ -153,8 +170,13 @@ def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
    except RuntimeError:
        if tensor.shape.numel() != torch.Size(shape).numel():
            raise ValueError('Cannot reshape tensor to target shape, number of elements must match') from None
        # most of the broadcasted dimensions can be preserved: only dimensions that are joined with
        # non-broadcasted dimensions cannot be preserved and must be made contiguous.
        # all dimensions that can be preserved as broadcasted are first collapsed to singleton,
        # such that contiguous does not create copies along these axes.
        idx = _reshape_idx(tensor.shape, shape, tensor.stride())
        # make contiguous only in dimensions in which broadcasting cannot be preserved
        semicontiguous = tensor[idx].contiguous()
        # finally, we can expand the broadcasted dimensions to the requested shape
        semicontiguous = semicontiguous.expand(tensor.shape)
        return semicontiguous.view(shape)
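
For context (not part of the commit), a minimal usage sketch of the function touched above; the tensor shape, the expand pattern, and the import path mrpro.utils.reshape are illustrative assumptions:

    import torch
    from mrpro.utils.reshape import reshape_broadcasted

    # dims 1 and 2 are broadcast (stride 0) after expand
    x = torch.randn(30, 1, 1, 3).expand(30, 2, 2, 3)
    y = reshape_broadcasted(x, 6, 5, 4, 3)
    print(y.shape)  # torch.Size([6, 5, 4, 3])
    # the values match an ordinary reshape, but broadcast dimensions are kept
    # uncopied where the grouping allows it
    assert torch.equal(y, x.reshape(6, 5, 4, 3))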
