Skip to content

Commit

Permalink
Ensure compressed number of coils is not greater than existing coils
Browse files Browse the repository at this point in the history
ghstack-source-id: 259899853d246003ca2223e6756b3f36d6841807
ghstack-comment-id: 2506408787
Pull Request resolved: #567
  • Loading branch information
fzimmermann89 committed Dec 16, 2024
1 parent 95859c4 commit 53103a5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/mrpro/data/KData.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,13 @@ def compress_coils(
from mrpro.operators import PCACompressionOp

coil_dim = -4 % self.data.ndim

if n_compressed_coils > (n_current_coils := self.data.shape[coil_dim]):
raise ValueError(
f'Number of compressed coils ({n_compressed_coils}) cannot be greater '
f'than the number of current coils ({n_current_coils}).'
)

if batch_dims is not None and joint_dims is not Ellipsis:
raise ValueError('Either batch_dims or joint_dims can be defined not both.')

Expand All @@ -347,22 +354,20 @@ def compress_coils(

# reshape to (*batch dimension, -1, coils)
permute_order = (
batch_dims_normalized
+ [i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized]
+ [coil_dim]
*batch_dims_normalized,
*[i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized],
coil_dim,
)
kdata_coil_compressed = self.data.permute(permute_order)
permuted_kdata_shape = kdata_coil_compressed.shape
kdata_coil_compressed = kdata_coil_compressed.flatten(
kdata_permuted = self.data.permute(permute_order)
kdata_flattened = kdata_permuted.flatten(
start_dim=len(batch_dims_normalized), end_dim=-2
) # keep separate dimensions and coil

pca_compression_op = PCACompressionOp(data=kdata_coil_compressed, n_components=n_compressed_coils)
(kdata_coil_compressed,) = pca_compression_op(kdata_coil_compressed)

pca_compression_op = PCACompressionOp(data=kdata_flattened, n_components=n_compressed_coils)
(kdata_coil_compressed_flattened,) = pca_compression_op(kdata_flattened)
del kdata_flattened
# reshape to original dimensions and undo permutation
kdata_coil_compressed = torch.reshape(
kdata_coil_compressed, [*permuted_kdata_shape[:-1], n_compressed_coils]
kdata_coil_compressed_flattened, [*kdata_permuted.shape[:-1], n_compressed_coils]
).permute(*np.argsort(permute_order))

return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone())
7 changes: 7 additions & 0 deletions tests/data/test_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,10 @@ def test_KData_compress_coils_error_coil_dim(consistently_shaped_kdata):

with pytest.raises(ValueError, match='Coil dimension must not'):
consistently_shaped_kdata.compress_coils(n_compressed_coils=3, joint_dims=(-4,))


def test_KData_compress_coils_error_n_coils(consistently_shaped_kdata):
"""Test if error is raised if new coils would be larger than existing coils"""
existing_coils = consistently_shaped_kdata.data.shape[-4]
with pytest.raises(ValueError, match='greater'):
consistently_shaped_kdata.compress_coils(existing_coils + 1)

0 comments on commit 53103a5

Please sign in to comment.