Skip to content

Commit

Permalink
Merge branch 'main' into cd_pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
schuenke authored Oct 10, 2024
2 parents 0b42e22 + b4f0e5b commit 20372ff
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/mrpro/algorithms/csm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from mrpro.algorithms.csm.iterative_walsh import iterative_walsh
from mrpro.algorithms.csm.walsh import walsh
from mrpro.algorithms.csm.inati import inati
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
"""Iterative Walsh method for coil sensitivity map calculation."""
"""(Iterative) Walsh method for coil sensitivity map calculation."""

import torch

from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.utils.filters import uniform_filter


def iterative_walsh(
coil_images: torch.Tensor,
smoothing_width: SpatialDimension[int] | int,
power_iterations: int,
) -> torch.Tensor:
def walsh(coil_images: torch.Tensor, smoothing_width: SpatialDimension[int] | int) -> torch.Tensor:
"""Calculate a coil sensitivity map (csm) using an iterative version of the Walsh method.
This is for a single set of coil images. The input should be a tensor with dimensions
Expand All @@ -26,13 +22,14 @@ def iterative_walsh(
images for each coil element
smoothing_width
width of the smoothing filter
power_iterations
number of iterations used to determine dominant eigenvector
References
----------
.. [WAL2000] Walsh DO, Gmitro AF, Marcellin MW (2000) Adaptive reconstruction of phased array MR imagery. MRM 43
"""
# After 10 power iterations we will have a very good estimate of the singular vector
n_power_iterations = 10

if isinstance(smoothing_width, int):
smoothing_width = SpatialDimension(smoothing_width, smoothing_width, smoothing_width)
# Compute the pointwise covariance between coils
Expand All @@ -44,7 +41,7 @@ def iterative_walsh(
# At each point in the image, find the dominant eigenvector
# of the signal covariance matrix using the power method
v = coil_covariance.sum(dim=0)
for _ in range(power_iterations):
for _ in range(n_power_iterations):
v /= v.norm(dim=0)
v = torch.einsum('abzyx,bzyx->azyx', coil_covariance, v)
csm = v / v.norm(dim=0)
Expand Down
7 changes: 2 additions & 5 deletions src/mrpro/data/CsmData.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def from_idata_walsh(
cls,
idata: IData,
smoothing_width: int | SpatialDimension[int] = 5,
power_iterations: int = 3,
chunk_size_otherdim: int | None = None,
) -> Self:
"""Create csm object from image data using iterative Walsh method.
Expand All @@ -33,20 +32,18 @@ def from_idata_walsh(
IData object containing the images for each coil element.
smoothing_width
width of smoothing filter.
power_iterations
number of iterations used to determine dominant eigenvector
chunk_size_otherdim:
How many elements of the other dimensions should be processed at once.
Default is None, which means that all elements are processed at once.
"""
from mrpro.algorithms.csm.iterative_walsh import iterative_walsh
from mrpro.algorithms.csm.walsh import walsh

# convert smoothing_width to SpatialDimension if int
if isinstance(smoothing_width, int):
smoothing_width = SpatialDimension(smoothing_width, smoothing_width, smoothing_width)

csm_fun = torch.vmap(
lambda img: iterative_walsh(img, smoothing_width, power_iterations),
lambda img: walsh(img, smoothing_width),
chunk_size=chunk_size_otherdim,
)
csm_tensor = csm_fun(idata.data)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""Tests the iterative Walsh algorithm."""

import torch
from mrpro.algorithms.csm import iterative_walsh
from mrpro.algorithms.csm import walsh
from mrpro.data import SpatialDimension
from tests.algorithms.csm.conftest import multi_coil_image
from tests.helper import relative_image_difference


def test_iterative_Walsh(ellipse_phantom, random_kheader):
"""Test the iterative Walsh method."""
def test_walsh(ellipse_phantom, random_kheader):
"""Test the Walsh method."""
idata, csm_ref = multi_coil_image(n_coils=4, ph_ellipse=ellipse_phantom, random_kheader=random_kheader)

# Estimate coil sensitivity maps.
# iterative_walsh should be applied for each other dimension separately
# walsh should be applied for each other dimension separately
smoothing_width = SpatialDimension(z=1, y=5, x=5)
csm = iterative_walsh(idata.data[0, ...], smoothing_width, power_iterations=3)
csm = walsh(idata.data[0, ...], smoothing_width)

# Phase is only relative in csm calculation, therefore only the abs values are compared.
assert relative_image_difference(torch.abs(csm), torch.abs(csm_ref[0, ...])) <= 0.01
2 changes: 1 addition & 1 deletion tests/data/test_csm_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from mrpro.data import CsmData, SpatialDimension

from tests.algorithms.csm.test_iterative_walsh import multi_coil_image
from tests.algorithms.csm.test_walsh import multi_coil_image
from tests.helper import relative_image_difference


Expand Down

0 comments on commit 20372ff

Please sign in to comment.