Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 authored Dec 2, 2024
1 parent b330f20 commit c6caee6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/mrpro/utils/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_str
If all axes that are grouped together in an "old" group are stride 0 (=broadcasted)
we can collapse them to singleton dimensions.
This function returns the indexer that either collapses dimensions to singleton or keeps all
elements, i.e. the slices in the returned list are all either slice(1) or slice(None).
elements, i.e. the slices in the returned list are all either slice(1) or slice(None).
"""
idx = []
pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1 # start from the right
Expand All @@ -146,7 +146,7 @@ def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_str
# preserve dimension
idx.extend([slice(None)] * len(group))
idx = idx[::-1] # we worked right to left, but our index should be left to right
return idx]
return idx


def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
Expand All @@ -172,8 +172,8 @@ def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
raise ValueError('Cannot reshape tensor to target shape, number of elements must match') from None
# most of the broadcasted dimensions can be preserved: only dimensions that are joined with non
# broadcasted dimensions can not be preserved and must be made contiguous.
# all dimensions that can be preserved as broadcasted are first collapsed to singleton, such that contiguous does not create
# copies along these axes.
# all dimensions that can be preserved as broadcasted are first collapsed to singleton,
# such that contiguous does not create copies along these axes.
idx = _reshape_idx(tensor.shape, shape, tensor.stride())
# make contiguous only in dimensions in which broadcasting cannot be preserved
semicontiguous = tensor[idx].contiguous()
Expand Down

0 comments on commit c6caee6

Please sign in to comment.