diff --git a/src/mrpro/data/KTrajectory.py b/src/mrpro/data/KTrajectory.py index 7a508279a..d38d76f7e 100644 --- a/src/mrpro/data/KTrajectory.py +++ b/src/mrpro/data/KTrajectory.py @@ -10,7 +10,7 @@ from mrpro.data.enums import TrajType from mrpro.data.MoveDataMixin import MoveDataMixin from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.utils import reduce_repeat +from mrpro.utils import remove_repeat from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues @@ -54,7 +54,7 @@ def as_any_float(tensor: torch.Tensor) -> torch.Tensor: if self.repeat_detection_tolerance is not None: kz, ky, kx = ( - as_any_float(reduce_repeat(tensor, self.repeat_detection_tolerance)) + as_any_float(remove_repeat(tensor, self.repeat_detection_tolerance)) for tensor in (self.kz, self.ky, self.kx) ) # use of setattr due to frozen dataclass diff --git a/src/mrpro/data/ReduceRepeatMixin.py b/src/mrpro/data/ReduceRepeatMixin.py deleted file mode 100644 index 16c2257a9..000000000 --- a/src/mrpro/data/ReduceRepeatMixin.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Mixin to reduce dimensions with repeated values to singleton in fields of dataclasses.""" - -import dataclasses -from collections.abc import Sequence -from typing import TypeVar, cast - -import torch - -from mrpro.data.Rotation import Rotation -from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.utils.reduce_repeat import reduce_repeat as remove_repeat_tensor -from mrpro.utils.typing import DataclassInstance - -T = TypeVar('T') - - -def remove_repeat(data: T, tol: float, dim: Sequence[int] | None = None) -> T: - """Replace dimensions with all equal values with singletons in fields. - - Handles Tensor, Rotation, and SpatialDimension fields. - - Parameters - ---------- - data: - Input data, must be real. - tol: - tolerance. - dim - dimensions to try to reduce to singletons. `None` means all. 
- """ - match data: - case torch.Tensor(): - return cast(T, remove_repeat_tensor(data, tol, dim)) - case SpatialDimension(z, y, x) if ( - isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor) and isinstance(z, torch.Tensor) - ): - return cast( - T, - data.__class__( - x=remove_repeat_tensor(x, tol, dim), - y=remove_repeat_tensor(y, tol, dim), - z=remove_repeat_tensor(z, tol, dim), - ), - ) - case Rotation(): - tensor = data.as_matrix().flatten(start_dim=-2) - tensor = torch.stack([remove_repeat_tensor(x, tol, dim) for x in tensor.unbind(-1)], -1).unflatten( - -1, (3, 3) - ) - return cast(T, data.__class__.from_matrix(tensor)) - case _: - return data - - -@dataclasses.dataclass -class ReduceRepeatMixin(DataclassInstance): - """Adds a __post_init__ method to remove repeated dimensions fields.""" - - def __init_subclass__(cls, reduce_repeat: bool = True, **kwargs): - """Initialize a checked data subclass.""" - super().__init_subclass__(**kwargs) - if reduce_repeat: - # inject the new post_init method - original_post_init = vars(cls).get('__post_init__') - - def new_post_init(self: ReduceRepeatMixin) -> None: - for field in dataclasses.fields(self): - setattr(self, field.name, remove_repeat(getattr(self, field.name), 1e-6)) - if original_post_init is not None: - original_post_init(self) - - cls.__post_init__ = new_post_init # type: ignore[attr-defined] diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index 27ae7aa26..8f9d6c962 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -17,8 +17,6 @@ from mrpro.data.QHeader import QHeader from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.ReduceRepeatMixin import ReduceRepeatMixin - __all__ = [ "AcqIdx", "AcqInfo", @@ -38,7 +36,6 @@ "MoveDataMixin", "QData", "QHeader", - "ReduceRepeatMixin", "Rotation", "SpatialDimension", "acq_filters", diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index ccef685d7..7172847cf 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -5,22 +5,18 @@ from mrpro.utils import unit_conversion from mrpro.utils.fill_range import fill_range_ from mrpro.utils.smap import smap -from mrpro.utils.indexing import Indexer -from mrpro.utils.reduce_repeat import reduce_repeat +from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted, ravel_multi_index, unsqueeze_tensors_left, unsqueeze_tensors_right, unsqueeze_at, unsqueeze_tensors_at from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin -from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop - __all__ = [ - "Indexer", "TensorAttributeMixin", "broadcast_right", "fill_range_", "ravel_multi_index", - "reduce_repeat", "reduce_view", + "remove_repeat", "reshape_broadcasted", "slice_profiles", "smap", @@ -34,4 +30,4 @@ "unsqueeze_tensors_left", "unsqueeze_tensors_right", "zero_pad_or_crop" -] +] \ No newline at end of file diff --git a/src/mrpro/utils/indexing.py b/src/mrpro/utils/indexing.py deleted file mode 100644 index c187a77f5..000000000 --- a/src/mrpro/utils/indexing.py +++ /dev/null @@ -1,344 +0,0 @@ -"""Indexer class for custom indexing with broadcasting.""" - -from collections.abc import Sequence -from typing import cast - -import torch -import torch.testing - -from mrpro.utils.reshape import 
reduce_view -from mrpro.utils.typing import TorchIndexerType - - -class Indexer: - """Custom Indexing with broadcasting. - - This class is used to index tensors in a way that is consistent - with the shape invariants of the data objects. - - On creation, an index and a shape are required. - When calling the Indexer with a tensor, the tensor is first broadcasted to the shape, - then the index is applied to the tensor. - - After indexing, remaining broadcasted dimensions are reduced to singleton dimensions again. - Thus, using the same Indexer on tensors with different singleton dimensions will - result in tensors with different shapes. All resulting tensors can be broadcasted - to the shape that would result in indexing a full tensor already having the desired shape. - - Indexing never removes dimensions, and can only add new dimensions at the beginning of the tensor. - - The index can contain slices, integers, boolean masks, sequences of integers, and integer tensors: - - Indexing with a slice - Behaves like in numpy, always returns a view. - Negative step sizes are not supported and will raise an IndexError. - slice(None), i.e. :, means selecting the whole axis. - - Indexing with an integer - If the index in bounds of the broadcasted shape, indexing behaves like slicing with index:index+1. - Otherwise, an IndexError is raised. - Always returns a view. - - Indexing with a boolean mask - Singleton dimensions in the mask are interpreted as full slices. This matches broadcasting of the mask to - the size of the respective axes of the tensor. - If the mask has more than one non-singleton dimension, a new dimension is added at the beginning of the tensor, - with length equal to the number of True values in the mask. - At the indexed axes, singleton dimensions are kept. - If the mask has only one non-singleton dimension, only the size of the indexed axes is changed. - Only a single boolean mask is allowed, otherwise an IndexError is raised. - - Indexing with a sequence of integers - If a single indexer is a sequence of integers, the result is as if each value of the sequence was used as an - integer index and the results were concatenated along the indexed dimension. - If more than one sequence of integers is used, a new dimension at the beginning of the tensor, - with the length equal to the shape of the sequences, is added. Indexed dimensions are kept as singleton. - The different sequences must have the same shape, otherwise an IndexError is raised. - Note that, as in numpy and torch, vectorized indexing is performed, not outer indexing. - - None - New axes can be added to the front of tensor by using None in the index. - This is only allowed at the beginning of the index. - - Ellipsis - An indexing expression can contain a single ellipsis, which will be expanded to slice(None) - for all axes that are not indexed. - - Implementation details: - - On creation, the indexing expression is parsed and split into two parts: normal_index and fancy_index. - - normal_index contains only indexing expressions that can be represented as view. - - fancy_index contains all other indexing expressions. - - On call - - the tensor is broadcasted to the desired shape - - the normal_index is applied. - - if required, the fancy_index is applied. - - remaining broadcasted dimensions are reduced to singleton dimensions. - """ - - def __init__(self, shape: tuple[int, ...], index: tuple[TorchIndexerType, ...]) -> None: - """Initialize the Indexer. 
- - Parameters - ---------- - shape - broadcasted shape of the tensors to index. All tensors will be broadcasted to this shape. - index - The index to apply to the tensors. - """ - normal_index: list[slice | int | None] = [] - """Used in phase 1 of the indexing, where we only consider integers and slices. Always does a view""" - fancy_index: list[slice | torch.Tensor | tuple[int, ...] | None] = [] - """All non normal indices. Might not be possible to do a view.""" - has_fancy_index = False - """Are there any advanced indices, such as boolean or integer array indices?""" - vectorized_shape: None | tuple[int, ...] = None - """Number of dimensions of the integer indices""" - expanded_index: list[slice | torch.Tensor | tuple[int, ...] | None | int] = [] - """"index with ellipsis expanded to full slices""" - - # basics checks and figuring out the number of axes already covered by the index, - # which is needed to determine the number of axes that covered by the ellipsis - has_ellipsis = False - has_boolean = False - covered_axes = 0 - for idx_ in index: - if idx_ is None: - if has_ellipsis or covered_axes: - raise IndexError('New axes are only allowed at the beginning of the index') - elif idx_ is Ellipsis: - if has_ellipsis: - raise IndexError('Only one ellipsis is allowed') - has_ellipsis = True - elif isinstance(idx_, torch.Tensor) and idx_.dtype == torch.bool: - if has_boolean: - raise IndexError('Only one boolean index is allowed') - has_boolean = True - covered_axes += idx_.ndim - elif isinstance(idx_, int | slice | torch.Tensor) or ( - isinstance(idx_, Sequence) and all(isinstance(el, int) for el in idx_) - ): - covered_axes += 1 - else: - raise IndexError(f'Unsupported index type {idx_}') - - if covered_axes > len(shape): - raise IndexError('Too many indices. 
Indexing more than the number of axes is not allowed') - - for idx_ in index: - if idx_ is Ellipsis: - # replacing ellipsis with full slices - expanded_index.extend([slice(None)] * (len(shape) - covered_axes)) - elif isinstance(idx_, torch.Tensor | int | slice | None): - expanded_index.append(idx_) - else: # must be Sequence[int], checked above - # for consistency, we convert all non-tensor sequences of integers to tuples - expanded_index.append(tuple(cast(Sequence[int], idx_))) - - if not has_ellipsis: - # if there is not ellipsis, we interpret the index as if it was followed by ellipsis - expanded_index.extend([slice(None)] * (len(shape) - covered_axes)) - - number_of_vectorized_indices: int = 0 - shape_position: int = 0 # current position in the shape that we are indexing - for idx in expanded_index: - if idx is None: - # we already checked that None is only allowed at the beginning of the index - normal_index.append(None) - fancy_index.append(slice(None)) - - elif isinstance(idx, int): - # always convert integers to slices - if not -shape[shape_position] <= idx < shape[shape_position]: - raise IndexError( - f'Index {idx} out of bounds for axis {shape_position} with shape {shape[shape_position]}' - ) - normal_index.append(slice(idx, idx + 1)) - fancy_index.append(slice(None)) - shape_position += 1 - - elif isinstance(idx, slice): - if idx.step is not None and idx.step < 0: - raise IndexError('Negative step size for slices is not supported') - normal_index.append(idx) - fancy_index.append(slice(None)) - shape_position += 1 - continue - - elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool: - # boolean indexing - has_fancy_index = True - - while idx.ndim and idx.shape[0] == 1: - # remove leading singleton dimensions and replace by full slices - idx = idx.squeeze(0) - fancy_index.append(slice(None)) - normal_index.append(slice(None)) - shape_position += 1 - - right_slice = [] - while idx.ndim and idx.shape[-1] == 1: - # remove trailing singleton dimensions and replace by full slices - idx = idx.squeeze(-1) - right_slice.append(slice(None)) - - if idx.ndim == 1: - # single boolean dimension remains - fancy_index.extend(torch.nonzero(idx, as_tuple=True)) - number_of_vectorized_indices += 1 - - elif idx.ndim > 1: - # more than one non singleton dimension - for ids, idx_shape, data_shape in zip( - idx.nonzero(as_tuple=True), - idx.shape, - shape[shape_position : shape_position + idx.ndim], - strict=True, - ): - if idx_shape == 1: - # we interpret singleton dimensions as full slices - fancy_index.append(slice(None)) - elif idx_shape != data_shape: - raise IndexError( - f'Boolean index has wrong shape, got {idx_shape} but expected {data_shape}' - ) - - else: - fancy_index.append(ids) - number_of_vectorized_indices += 1 - else: - # all singleton boolean mask - pass - - normal_index.extend(idx.ndim * [slice(None)]) - normal_index.extend(right_slice) - fancy_index.extend(right_slice) - shape_position += idx.ndim + len(right_slice) - - elif isinstance(idx, torch.Tensor) and idx.dtype in ( - torch.int64, # long - torch.int32, # int - torch.int16, - torch.int8, - torch.uint16, - torch.uint32, - torch.uint64, - ): - # integer array indexing - if (idx >= shape[shape_position]).any() or (idx < -shape[shape_position]).any(): - raise IndexError( - 'Index out of bounds. 
' - f'Got values in the inerval [{idx.min()}, {idx.max() + 1}) for axis {shape_position} ' - f'with shape {shape[shape_position]}' - ) - if vectorized_shape is not None and vectorized_shape != idx.shape: - raise IndexError( - f'All vectorized indices must have the same shape. Got {idx.shape} and {vectorized_shape}' - ) - vectorized_shape = idx.shape - has_fancy_index = True - shape_position += 1 - number_of_vectorized_indices += 1 - normal_index.append(slice(None)) - fancy_index.append(idx.to(torch.int64)) - - elif isinstance(idx, tuple): - # integer Sequence - if any(el >= shape[shape_position] or el < -shape[shape_position] for el in idx): - raise IndexError( - 'Index out of bounds. ' - f'Got values in the interval [{min(idx)}, {max(idx) + 1}) for axis {shape_position} ' - f'with shape {shape[shape_position]}' - ) - if vectorized_shape is not None and vectorized_shape != (len(idx),): - raise IndexError('All vectorized indices must have the same shape') - vectorized_shape = (len(idx),) - has_fancy_index = True - shape_position += 1 - normal_index.append(slice(None)) - fancy_index.append(idx) - number_of_vectorized_indices += 1 - - else: # torch.Tensor - raise IndexError(f'Unsupported index dtype {idx.dtype}') - - self.move_axes: tuple[tuple[int, ...], tuple[int, ...]] = ((), ()) - """final move-axes operation to move the vectorized indices to the beginning of the tensor""" - self.more_than_one_vectorized_index = number_of_vectorized_indices > 1 - """there is more than one vectorized index, thus a new axis will be added""" - - if self.more_than_one_vectorized_index: - # torch indexing would remove the dimensions, we want to keep them - # as singleton dimension -> we need to add a new axis. - # inserting it in between the indices forces the dimension adeed by - # vectorized indices to be always at the beginning of the result. - self.more_than_one_vectorized_index = True - new_fancy_index = [] - for idx in fancy_index: - new_fancy_index.append(idx) - if isinstance(idx, torch.Tensor): - new_fancy_index.append(None) - if isinstance(idx, tuple): - new_fancy_index.append(None) - fancy_index = new_fancy_index - - elif vectorized_shape is not None and len(vectorized_shape) != 1: - # for a single and vectorized index, torch would insert it at the same position - # this would shift the other axes, potentially causing violations of the shape invariants. 
- # thus, we move the inserted axis to the beginning of the tensor, after axes inserted by None - move_source_start = next(i for i, idx in enumerate(fancy_index) if isinstance(idx, torch.Tensor)) - move_source = tuple(range(move_source_start, move_source_start + len(vectorized_shape))) - move_target_start = next(i for i, idx in enumerate(fancy_index) if idx is not None) - move_target = tuple(range(move_target_start, move_target_start + len(vectorized_shape))) - self.move_axes = (move_source, move_target) - # keep a singleton axes at the indexed axis - fancy_index.insert(move_source_start + 1, None) - - self.fancy_index = tuple(fancy_index) if has_fancy_index else () - self.normal_index = tuple(normal_index) - self.shape = shape - - def __call__(self, tensor: torch.Tensor) -> torch.Tensor: - """Apply the index to a tensor.""" - try: - tensor = tensor.broadcast_to(self.shape) - except RuntimeError: - raise IndexError('Tensor cannot be broadcasted to the desired shape') from None - - tensor = tensor[self.normal_index] # will always be a view - - if not self.fancy_index: - # nothing more to do - tensor = reduce_view(tensor) - return tensor - - # we need to modify the fancy index to efficiently handle broadcasted dimensions - fancy_index: list[None | tuple[int, ...] | torch.Tensor | slice] = [] - tensor_index = 0 - stride = tensor.stride() - for idx in self.fancy_index: - if idx is None: - fancy_index.append(idx) - # don't increment tensor_index as this is a new axis - continue - if stride[tensor_index] == 0: - # broadcasted dimension - if isinstance(idx, slice): # can only be slice(None) here - # collapse broadcasted dimensions to singleton, i.e. keep the dimension - fancy_index.append(slice(0, 1)) - elif not self.more_than_one_vectorized_index and isinstance(idx, tuple): - # as the dimension only exists due to broadcasting, it should be reduced to singleton - # there is already a None inserted after the index, so we don't need to keep the dimension - fancy_index.append((0,)) - elif not self.more_than_one_vectorized_index and isinstance(idx, torch.Tensor): - # same, but with more dimensions in the single vectorized index - # these axes will later be moved to the beginning of the tensor, as they would - # if the dimensions were not broadcasted - fancy_index.append(idx.new_zeros([1] * idx.ndim)) - else: - fancy_index.append(idx) - else: - fancy_index.append(idx) - tensor_index += 1 - - tensor = tensor[fancy_index] - - if self.move_axes[0]: - # handle the special case of a single and integer index, where we need to move the new - # axis to the beginning of the tensor - tensor = tensor.moveaxis(self.move_axes[0], self.move_axes[1]) - - return tensor diff --git a/src/mrpro/utils/reduce_repeat.py b/src/mrpro/utils/remove_repeat.py similarity index 50% rename from src/mrpro/utils/reduce_repeat.py rename to src/mrpro/utils/remove_repeat.py index 243f03dd8..85555c6f9 100644 --- a/src/mrpro/utils/reduce_repeat.py +++ b/src/mrpro/utils/remove_repeat.py @@ -1,33 +1,24 @@ """remove_repeat utility function.""" -from collections.abc import Sequence - import torch -def reduce_repeat(tensor: torch.Tensor, tol: float = 1e-6, dim: Sequence[int] | None = None) -> torch.Tensor: +def remove_repeat(tensor: torch.Tensor, tol: float) -> torch.Tensor: """Replace dimensions with all equal values with singletons. Parameters ---------- - tensor - Input tensor - tol - tolerance. - dim - dimensions to try to reduce to singletons. `None` means all. + tensor: + Input tensor, must be real. 
+ tol: + tolerance """ - if tensor.is_complex(): - real = reduce_repeat(tensor.real, tol, dim) - imag = reduce_repeat(tensor.imag, tol, dim) - return real + 1j * imag def can_be_singleton(dim: int) -> bool: # If the distance between min and max is smaller than the tolerance, all values are the same. return bool(torch.all((tensor.amax(dim=dim) - tensor.amin(dim=dim)) <= tol).item()) - dims = dim if dim is not None else range(tensor.ndim) take_first = slice(0, 1) take_all = slice(None) - index = tuple(take_first if can_be_singleton(dim) else take_all for dim in dims) + index = tuple(take_first if can_be_singleton(dim) else take_all for dim in range(tensor.ndim)) return tensor[index] diff --git a/tests/data/test_reducerepeatmixin.py b/tests/data/test_reducerepeatmixin.py deleted file mode 100644 index 95fa4e907..000000000 --- a/tests/data/test_reducerepeatmixin.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Test the ReduceRepeatMixin.""" - -import dataclasses - -import torch -from mrpro.data import ReduceRepeatMixin, Rotation, SpatialDimension - -from tests import RandomGenerator - - -@dataclasses.dataclass -class Dummy(ReduceRepeatMixin): - a: torch.Tensor - b: SpatialDimension - c: Rotation - - def __post_init__(self) -> None: - self.a.ravel()[0] = 1.0 - - -def test_reducerepeatmixin() -> None: - """Test ReduceRepeatMixin.""" - rng = RandomGenerator(10) - - a = rng.float32_tensor((5, 1, 1, 1)) - a_expanded = a.expand(5, 2, 3, 1) - - b = SpatialDimension(*rng.float32_tensor((3, 1, 1, 3))) - b_expanded = SpatialDimension(*[x.expand(1, 2, 3) for x in b.zyx]) - - c_matrix = torch.eye(3).reshape(1, 1, 3, 3) - c_expanded = Rotation.from_matrix(c_matrix.expand(5, 2, 3, 3)) - - test = Dummy(a_expanded, b_expanded, c_expanded) - - torch.testing.assert_close(test.a, a) - torch.testing.assert_close(test.b.z, b.z) - torch.testing.assert_close(test.b.y, b.y) - torch.testing.assert_close(test.b.x, b.x) - torch.testing.assert_close(test.c.as_matrix(), c_matrix) - - assert test.a[0, 0, 0, 0] == 1.0, 'subclass post_init not called' diff --git a/tests/utils/test_indexer.py b/tests/utils/test_indexer.py deleted file mode 100644 index 6d89019f4..000000000 --- a/tests/utils/test_indexer.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -import torch -from mrpro.utils.indexing import Indexer - - -@pytest.mark.parametrize( - ('shape', 'broadcast_shape', 'index', 'expected_shape'), - [ - ((1, 6, 7), (5, 6, 7), (slice(None), torch.ones(4, 5).int(), slice(None)), (4, 5, 1, 1, 7)), # array index - ((1, 6, 7), (5, 6, 7), (slice(None), slice(None), slice(None)), (1, 6, 7)), # nothing - ((1, 6, 7), (5, 6, 7), (), (1, 6, 7)), # nothing - ((5, 6, 1), (5, 6, 7), (1,), (1, 6, 1)), # integer indexing - ((5, 1, 1), (5, 6, 7), (slice(None), 2), (5, 1, 1)), # integer indexing broadcast - ((1, 1, 7), (5, 6, 7), (slice(1, 2), slice(None, None, 2), slice(1, None, 2)), (1, 1, 3)), # slices - ((5, 1, 1), (5, 6, 7), (torch.tensor([0, 2]),), (2, 1, 1)), # array index - ((1, 6, 1), (5, 6, 7), (slice(None), [0, 2]), (1, 2, 1)), # integer list - ((1, 1, 1), (5, 6, 7), (torch.tensor([0, 2]), slice(None), slice(None)), (1, 1, 1)), # array index broadcast - ((5, 1, 1), (5, 6, 7), (torch.tensor([0, 2]), slice(None), (0, 2)), (2, 1, 1, 1)), # two array indicces - ((5, 6, 7), (5, 6, 7), (slice(None), torch.tensor([1, 3]), slice(None)), (5, 2, 7)), - ((5, 6, 7), (5, 6, 7), (slice(None), slice(None), torch.tensor([0, 5])), (5, 6, 2)), - ((5, 6, 7), (5, 6, 7), (torch.tensor([True, False, True, False, True]), slice(None), slice(None)), (3, 6, 7)), - ((5, 6, 
7), (5, 6, 7), (torch.ones(5, 6, 7).bool(),), (210, 1, 1, 1)), - ( - (5, 6, 7), - (5, 6, 7), - (torch.tensor([True, False, True, True, False]), slice(None), slice(None)), - (3, 6, 7), - ), - ((5, 6, 7), (5, 6, 7), (torch.ones(5).bool(), ...), (5, 6, 7)), - ((5, 6, 7), (5, 6, 7), (torch.tensor([[0, 1], [2, 3]]), slice(None), slice(None)), (2, 2, 1, 6, 7)), - ((5, 1, 7), (5, 6, 7), (slice(None), torch.tensor([[1, 2], [3, 4]]), slice(None)), (1, 1, 5, 1, 7)), - ((5, 6, 7), (5, 6, 7), (slice(None), torch.tensor([True, False, True, True, False, True]), 0), (5, 4, 1)), - ((5, 6, 7), (5, 6, 7), (torch.tensor([1, 3]), slice(None), torch.tensor([0, 5])), (2, 1, 6, 1)), - ((5, 6, 7), (5, 6, 7), (None, slice(None), slice(None)), (1, 5, 6, 7)), - ((5, 6, 7), (5, 6, 7), (None, None, slice(None), slice(None)), (1, 1, 5, 6, 7)), - ((5, 6, 7), (5, 6, 7), (..., 0), (5, 6, 1)), - ((5, 6, 7), (5, 6, 7), (slice(None), ..., slice(None)), (5, 6, 7)), - ((5, 6, 7), (5, 6, 7), (1, ..., 3), (1, 6, 1)), - ((5, 1, 7), (5, 6, 7), (torch.ones(1, 1, 1, dtype=torch.bool),), (5, 1, 7)), - ((5, 1, 1), (5, 6, 7), (torch.ones(1, 6, 1, dtype=torch.bool),), (5, 1, 1)), - ], -) -def test_indexer(shape, broadcast_shape, index, expected_shape): - tensor = torch.arange(int(torch.prod(torch.tensor(shape)))).reshape(shape) - indexer = Indexer(broadcast_shape, index) - result = indexer(tensor) - assert result.shape == expected_shape - - -@pytest.mark.parametrize( - ('index', 'error_message'), - [ - # Type errors - ('invalid_index', 'Unsupported index type'), - (torch.tensor([1.0]), 'Unsupported index dtype'), - ((Ellipsis, Ellipsis), 'Only one ellipsis is allowed'), - ((slice(None), None), 'New axes are only allowed at the beginning'), - ((slice(None, None, -1),), 'Negative step size for slices is not supported'), - ((0, 1, 2, 3, 4, 5), 'Too many indices'), - ((5,), 'Index 5 out of bounds'), - ((torch.tensor([10]),), 'Index out of bounds'), - (([10],), r'Index out of bounds. Got values in the interval \[10, 11\)'), - ((torch.ones(3, 1, 6, dtype=torch.bool),), 'Boolean index has wrong shape'), - (([0, 1], [0, 1, 2]), 'All vectorized indices must have the same shape'), - (([0, 1], torch.zeros(2, 1).int()), 'All vectorized indices must have the same shape'), - ((torch.tensor([True, False]), torch.tensor([True, False])), 'Only one boolean index is allowed'), - ], -) -def test_indexer_invalid_indexing(index, error_message): - """Test various invalid indexing scenarios.""" - shape = (3, 4, 5) - with pytest.raises(IndexError, match=error_message): - Indexer(shape, index) - - -def test_indexer_broadcast_error(): - """Test error when tensor cannot be broadcast to target shape.""" - shape = (3, 4, 5) - tensor = torch.arange(24).reshape(2, 3, 4) - indexer = Indexer(shape, (slice(None),) * 3) - - with pytest.raises(IndexError, match='cannot be broadcasted'): - indexer(tensor)
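Note on the renamed utility (commentary, not part of the patch): after this change, remove_repeat always scans every dimension and expects a real-valued tensor; the former dim argument and the complex-tensor branch of reduce_repeat are gone. A minimal usage sketch, assuming the patch above is applied; the example tensor, tolerance value, and printed shape are illustrative only:

import torch
from mrpro.utils import remove_repeat

# Values vary only along dim 1; dims 0 and 2 are broadcast copies of the same values.
k = torch.arange(6.0).reshape(1, 6, 1).expand(5, 6, 7)

# Dimensions whose max-min spread is within tol are collapsed to singletons.
reduced = remove_repeat(k, tol=1e-6)
print(reduced.shape)  # torch.Size([1, 6, 1])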