From 96184659d74b7546656dbde6f59cdf99ad11173c Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Thu, 24 Oct 2024 09:11:51 +0200 Subject: [PATCH 01/23] Add pre-commit check for __all__ (#471) --- .pre-commit-config.yaml | 8 ++++++ src/mrpro/__init__.py | 10 +++++-- src/mrpro/algorithms/__init__.py | 2 +- src/mrpro/algorithms/csm/__init__.py | 2 +- .../algorithms/reconstruction/__init__.py | 6 ++++- src/mrpro/data/__init__.py | 26 ++++++++++++++++++- src/mrpro/data/traj_calculators/__init__.py | 10 ++++++- src/mrpro/operators/__init__.py | 9 ++++--- src/mrpro/operators/models/__init__.py | 10 ++++++- src/mrpro/phantoms/__init__.py | 2 +- src/mrpro/utils/__init__.py | 13 +++++++++- 11 files changed, 85 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b14c4344b..7803441eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,13 @@ repos: hooks: - id: typos + - repo: https://github.com/fzimmermann89/check_all + rev: v1.0 + hooks: + - id: check-init-all + args: [--double-quotes] + exclude: ^tests/ + - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.11.2 hooks: @@ -45,6 +52,7 @@ repos: - xsdata - "--index-url=https://download.pytorch.org/whl/cpu" - "--extra-index-url=https://pypi.python.org/simple" + ci: autofix_commit_msg: | [pre-commit] auto fixes from pre-commit hooks diff --git a/src/mrpro/__init__.py b/src/mrpro/__init__.py index 4a0ac4ca0..729ae188c 100644 --- a/src/mrpro/__init__.py +++ b/src/mrpro/__init__.py @@ -1,4 +1,10 @@ from mrpro._version import __version__ from mrpro import algorithms, operators, data, phantoms, utils -__all__ = ["algorithms", "operators", "data", "phantoms", "utils"] - +__all__ = [ + "__version__", + "algorithms", + "data", + "operators", + "phantoms", + "utils" +] diff --git a/src/mrpro/algorithms/__init__.py b/src/mrpro/algorithms/__init__.py index 6a5d8cd64..3dbc8ed07 100644 --- a/src/mrpro/algorithms/__init__.py +++ b/src/mrpro/algorithms/__init__.py @@ -1,3 +1,3 @@ from mrpro.algorithms import csm, optimizers, reconstruction from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -__all__ = ["csm", "optimizers", "reconstruction", "prewhiten_kspace"] +__all__ = ["csm", "optimizers", "prewhiten_kspace", "reconstruction"] \ No newline at end of file diff --git a/src/mrpro/algorithms/csm/__init__.py b/src/mrpro/algorithms/csm/__init__.py index 499984f53..255e09a89 100644 --- a/src/mrpro/algorithms/csm/__init__.py +++ b/src/mrpro/algorithms/csm/__init__.py @@ -1,3 +1,3 @@ from mrpro.algorithms.csm.walsh import walsh from mrpro.algorithms.csm.inati import inati -__all__ = ["walsh", "inati"] +__all__ = ["inati", "walsh"] \ No newline at end of file diff --git a/src/mrpro/algorithms/reconstruction/__init__.py b/src/mrpro/algorithms/reconstruction/__init__.py index 1d292cf37..180186dd4 100644 --- a/src/mrpro/algorithms/reconstruction/__init__.py +++ b/src/mrpro/algorithms/reconstruction/__init__.py @@ -1,4 +1,8 @@ from mrpro.algorithms.reconstruction.Reconstruction import Reconstruction from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction from mrpro.algorithms.reconstruction.IterativeSENSEReconstruction import IterativeSENSEReconstruction -__all__ = ["Reconstruction", "DirectReconstruction", "IterativeSENSEReconstruction"] +__all__ = [ + "DirectReconstruction", + "IterativeSENSEReconstruction", + "Reconstruction" +] \ No newline at end of file diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index a954ca1b3..b5034d668 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -17,4 +17,28 @@ from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension from mrpro.data.TrajectoryDescription import TrajectoryDescription -__all__ = ["enums", "traj_calculators", "acq_filters", "AcqIdx", "AcqInfo", "CsmData", "Data", "DcfData", "EncodingLimits", "Limits", "IData", "IHeader", "KData", "KHeader", "KNoise", "KTrajectory", "KTrajectoryRawShape", "MoveDataMixin", "QData", "QHeader", "Rotation", "SpatialDimension", "TrajectoryDescription"] +__all__ = [ + "AcqIdx", + "AcqInfo", + "CsmData", + "Data", + "DcfData", + "EncodingLimits", + "IData", + "IHeader", + "KData", + "KHeader", + "KNoise", + "KTrajectory", + "KTrajectoryRawShape", + "Limits", + "MoveDataMixin", + "QData", + "QHeader", + "Rotation", + "SpatialDimension", + "TrajectoryDescription", + "acq_filters", + "enums", + "traj_calculators" +] \ No newline at end of file diff --git a/src/mrpro/data/traj_calculators/__init__.py b/src/mrpro/data/traj_calculators/__init__.py index 72d61536b..2fd0e5b5a 100644 --- a/src/mrpro/data/traj_calculators/__init__.py +++ b/src/mrpro/data/traj_calculators/__init__.py @@ -5,4 +5,12 @@ from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd from mrpro.data.traj_calculators.KTrajectoryPulseq import KTrajectoryPulseq from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian -__all__ = ["KTrajectoryCalculator", "KTrajectoryRpe", "KTrajectorySunflowerGoldenRpe", "KTrajectoryRadial2D", "KTrajectoryIsmrmrd", "KTrajectoryPulseq", "KTrajectoryCartesian"] +__all__ = [ + "KTrajectoryCalculator", + "KTrajectoryCartesian", + "KTrajectoryIsmrmrd", + "KTrajectoryPulseq", + "KTrajectoryRadial2D", + "KTrajectoryRpe", + "KTrajectorySunflowerGoldenRpe" +] \ No newline at end of file diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 4fe58f1e3..345c337fa 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -23,8 +23,6 @@ from mrpro.operators.ZeroOp import ZeroOp __all__ = [ - "functionals", - "models", "CartesianSamplingOp", "ConstraintsOp", "DensityCompensationOp", @@ -39,14 +37,19 @@ "IdentityOp", "LinearOperator", "MagnitudeOp", + "MultiIdentityOp", "Operator", "PhaseOp", "ProximableFunctional", "ProximableFunctionalSeparableSum", + "ScaledFunctional", + "ScaledProximableFunctional", "SensitivityOp", "SignalModel", "SliceProjectionOp", "WaveletOp", "ZeroOp", "ZeroPadOp", -] + "functionals", + "models" +] \ No newline at end of file diff --git a/src/mrpro/operators/models/__init__.py b/src/mrpro/operators/models/__init__.py index 0ec3cc517..629a560d3 100644 --- a/src/mrpro/operators/models/__init__.py +++ b/src/mrpro/operators/models/__init__.py @@ -5,4 +5,12 @@ from mrpro.operators.models.WASABITI import WASABITI from mrpro.operators.models.MonoExponentialDecay import MonoExponentialDecay from mrpro.operators.models.TransientSteadyStateWithPreparation import TransientSteadyStateWithPreparation -__all__ = ["SaturationRecovery", "InversionRecovery", "MOLLI", "WASABI", "WASABITI", "MonoExponentialDecay", "TransientSteadyStateWithPreparation"] +__all__ = [ + "InversionRecovery", + "MOLLI", + "MonoExponentialDecay", + "SaturationRecovery", + "TransientSteadyStateWithPreparation", + "WASABI", + "WASABITI" +] \ No newline at end of file diff --git a/src/mrpro/phantoms/__init__.py b/src/mrpro/phantoms/__init__.py index 3c9046364..081d2e465 100644 --- a/src/mrpro/phantoms/__init__.py +++ b/src/mrpro/phantoms/__init__.py @@ -1,3 +1,3 @@ from mrpro.phantoms.EllipsePhantom import EllipsePhantom from mrpro.phantoms.phantom_elements import EllipseParameters -__all__ = ["EllipsePhantom", "EllipseParameters"] +__all__ = ["EllipseParameters", "EllipsePhantom"] \ No newline at end of file diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index ce3daf2c8..c09071f4b 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -6,4 +6,15 @@ from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right -__all__ = ["slice_profiles", "typing", "smap", "remove_repeat", "zero_pad_or_crop", "modify_acq_info", "split_idx", "broadcast_right", "unsqueeze_left", "unsqueeze_right"] +__all__ = [ + "broadcast_right", + "modify_acq_info", + "remove_repeat", + "slice_profiles", + "smap", + "split_idx", + "typing", + "unsqueeze_left", + "unsqueeze_right", + "zero_pad_or_crop" +] \ No newline at end of file From 9c28a48aeed9dafe6a249a5ddd1d5b62581e9ac0 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:41:06 +0200 Subject: [PATCH 02/23] Fix PyTest and Deployment workflow triggers (#467) --- .github/workflows/deployment.yml | 49 +++++++++++++++++++------------- .github/workflows/pytest.yml | 48 ++++++++++++++++--------------- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index 9f02e8dcc..dd900496e 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -1,4 +1,4 @@ -name: Publish to PyPI +name: Build and Deploy Package on: push: @@ -7,21 +7,31 @@ on: jobs: build-testpypi-package: - name: Modify version and build package for TestPyPI + name: Build Package for TestPyPI runs-on: ubuntu-latest outputs: version: ${{ steps.set_suffix.outputs.version }} suffix: ${{ steps.set_suffix.outputs.suffix }} + version_changed: ${{ steps.changes.outputs.version_changed }} steps: - - name: Checkout code - uses: actions/checkout@v4 - - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.x" - - name: Set environment variable for version suffix + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Check if VERSION file is modified compared to main + uses: dorny/paths-filter@v3 + id: changes + with: + base: main + filters: | + version_changed: + - 'src/mrpro/VERSION' + + - name: Set Version Suffix id: set_suffix run: | VERSION=$(cat src/mrpro/VERSION) @@ -30,12 +40,12 @@ jobs: echo "suffix=$SUFFIX" >> $GITHUB_OUTPUT echo "version=$VERSION" >> $GITHUB_OUTPUT - - name: Build package with version suffix + - name: Build Package run: | python -m pip install --upgrade build python -m build - - name: Store TestPyPi distribution + - name: Upload TestPyPI Distribution Artifact uses: actions/upload-artifact@v4 with: name: testpypi-package-distribution @@ -46,6 +56,7 @@ jobs: needs: - build-testpypi-package runs-on: ubuntu-latest + if: needs.build-testpypi-package.outputs.version_changed == 'true' environment: name: testpypi @@ -55,7 +66,7 @@ jobs: id-token: write steps: - - name: Download TestPyPi distribution + - name: Download TestPyPI Distribution uses: actions/download-artifact@v4 with: name: testpypi-package-distribution @@ -68,7 +79,7 @@ jobs: verbose: true test-install-from-testpypi: - name: Test installation from TestPyPI + name: Test Installation from TestPyPI needs: - testpypi-deployment - build-testpypi-package @@ -86,14 +97,14 @@ jobs: python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ build-pypi-package: - name: Build package for PyPI + name: Build Package for PyPI runs-on: ubuntu-latest needs: - test-install-from-testpypi outputs: version: ${{ steps.get_version.outputs.version }} steps: - - name: Checkout code + - name: Checkout Code uses: actions/checkout@v4 - name: Set up Python @@ -101,22 +112,22 @@ jobs: with: python-version: "3.x" - - name: Install setuptools_git_versioning + - name: Install Automatic Versioning Tool run: | python -m pip install setuptools-git-versioning - - name: Get current version + - name: Get Current Version id: get_version run: | VERSION=$(python -m setuptools_git_versioning) echo "VERSION=$VERSION" >> $GITHUB_OUTPUT - - name: Build package + - name: Build Package run: | python -m pip install --upgrade build python -m build - - name: Store PyPi distribution + - name: Store PyPI Distribution uses: actions/upload-artifact@v4 with: name: pypi-package-distribution @@ -138,13 +149,13 @@ jobs: id-token: write steps: - - name: Download PyPi distribution + - name: Download PyPI Distribution uses: actions/download-artifact@v4 with: name: pypi-package-distribution path: dist/ - - name: Create tag + - name: Create Tag uses: actions/github-script@v7 with: script: | @@ -155,7 +166,7 @@ jobs: sha: context.sha }) - - name: Create release + - name: Create Release uses: actions/github-script@v7 with: github-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 427b62e96..53d1343c7 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -2,17 +2,21 @@ name: PyTest on: pull_request: + push: + branches: + - main jobs: get_dockerfiles: - name: Get list of dockerfiles for different containers + name: Get List of Dockerfiles for Containers runs-on: ubuntu-latest permissions: packages: read outputs: imagenames: ${{ steps.set-matrix.outputs.imagenames }} steps: - - id: set-matrix + - name: Retrieve Docker Image Names + id: set-matrix env: GH_TOKEN: ${{ secrets.GHCR_TOKEN }} run: | @@ -35,58 +39,56 @@ jobs: echo "image names with tag latest: $imagenames_latest" echo "imagenames=$imagenames_latest" >> $GITHUB_OUTPUT - - name: Dockerfile overview + - name: Dockerfile Overview run: | - echo "final list of images with tag latest: ${{ steps.set-matrix.outputs.imagenames }}" + echo "Final list of images with tag latest: ${{ steps.set-matrix.outputs.imagenames }}" test: - name: Run tests and get coverage report + name: Run Tests and Coverage Report needs: get_dockerfiles runs-on: ubuntu-latest permissions: pull-requests: write contents: write strategy: - fail-fast: false + fail-fast: false matrix: imagename: ${{ fromJson(needs.get_dockerfiles.outputs.imagenames) }} - # runs within Docker container container: image: ghcr.io/ptb-mr/${{ matrix.imagename }}:latest options: --user runner - steps: - - name: Checkout repo + - name: Checkout Repository uses: actions/checkout@v4 - - name: Install mrpro and dependencies - run: pip install --upgrade --upgrade-strategy "eager" .[test] + - name: Install MRpro and Dependencies + run: pip install --upgrade --upgrade-strategy eager .[test] - - name: Install pytest-github-actions-annotate-failures plugin + - name: Install PyTest GitHub Annotation Plugin run: pip install pytest-github-actions-annotate-failures - - name: Run PyTest + - name: Run PyTest and Generate Coverage Report run: | - pytest -n 4 -m "not cuda" --junitxml=pytest.xml --cov-report=term-missing:skip-covered --cov=mrpro | tee pytest-coverage.txt + pytest -n 4 -m "not cuda" --junitxml=pytest.xml \ + --cov-report=term-missing:skip-covered --cov=mrpro | tee pytest-coverage.txt - - name: Check for pytest.xml + - name: Verify PyTest XML Output run: | - if [ -f pytest.xml ]; then - echo "pytest.xml file found. Continuing..." - else - echo "pytest.xml file not found. Please check previous 'Run PyTest' section for errors." + if [ ! -f pytest.xml ]; then + echo "PyTest XML report not found. Please check the previous 'Run PyTest' step for errors." exit 1 fi - - name: Pytest coverage comment + - name: Post PyTest Coverage Comment id: coverageComment uses: MishaKav/pytest-coverage-comment@v1.1.53 with: pytest-coverage-path: ./pytest-coverage.txt junitxml-path: ./pytest.xml - - name: Create the Badge + - name: Create Coverage Badge on Main Branch Push uses: schneegans/dynamic-badges-action@v1.7.0 + if: github.event_name == 'push' && github.ref == 'refs/heads/main' with: auth: ${{ secrets.GIST_SECRET }} gistID: 48e334a10caf60e6708d7c712e56d241 @@ -96,12 +98,12 @@ jobs: color: ${{ steps.coverageComment.outputs.color }} namedLogo: python - - name: Set pipeline status + - name: Set Pipeline Status Based on Test Results if: steps.coverageComment.outputs.errors != 0 || steps.coverageComment.outputs.failures != 0 uses: actions/github-script@v7 with: script: | - core.setFailed('PyTest workflow failed with ${{ steps.coverageComment.outputs.errors }} errors and ${{ steps.coverageComment.outputs.failures }} failures.') + core.setFailed("PyTest workflow failed with ${{ steps.coverageComment.outputs.errors }} errors and ${{ steps.coverageComment.outputs.failures }} failures.") concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} From ec503af94aef49298b194d9a14ff476cf880930b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Thu, 24 Oct 2024 10:44:00 +0200 Subject: [PATCH 03/23] Remove superfluous .forward (#472) --- examples/pulseq_2d_radial_golden_angle.ipynb | 6 +++--- examples/pulseq_2d_radial_golden_angle.py | 6 +++--- src/mrpro/operators/FourierOp.py | 2 +- tests/operators/models/test_inversion_recovery.py | 4 ++-- tests/operators/models/test_molli.py | 4 ++-- tests/operators/models/test_mono_exponential_decay.py | 4 ++-- tests/operators/models/test_saturation_recovery.py | 4 ++-- .../test_transient_steady_state_with_preparation.py | 6 +++--- tests/operators/models/test_wasabi.py | 8 ++++---- tests/operators/models/test_wasabiti.py | 10 +++++----- tests/operators/test_sensitivity_op.py | 2 +- tests/operators/test_zero_pad_op.py | 2 +- 12 files changed, 29 insertions(+), 29 deletions(-) diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index d0e6e0adb..52e0310bb 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -76,7 +76,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_ismrmrd_traj = direct_reconstruction.forward(kdata)" + "img_using_ismrmrd_traj = direct_reconstruction(kdata)" ] }, { @@ -100,7 +100,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_rad2d_traj = direct_reconstruction.forward(kdata)" + "img_using_rad2d_traj = direct_reconstruction(kdata)" ] }, { @@ -143,7 +143,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_pulseq_traj = direct_reconstruction.forward(kdata)" + "img_using_pulseq_traj = direct_reconstruction(kdata)" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index 955b20a37..3f857c382 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -37,7 +37,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_ismrmrd_traj = direct_reconstruction.forward(kdata) +img_using_ismrmrd_traj = direct_reconstruction(kdata) # %% [markdown] # ### Image reconstruction using KTrajectoryRadial2D @@ -49,7 +49,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_rad2d_traj = direct_reconstruction.forward(kdata) +img_using_rad2d_traj = direct_reconstruction(kdata) # %% [markdown] # ### Image reconstruction using KTrajectoryPulseq @@ -73,7 +73,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_pulseq_traj = direct_reconstruction.forward(kdata) +img_using_pulseq_traj = direct_reconstruction(kdata) # %% [markdown] # ### Plot the different reconstructed images diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index 79351da76..f4254d8d2 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -148,7 +148,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: """ if len(self._fft_dims): # FFT - (x,) = self._fast_fourier_op.forward(x) + (x,) = self._fast_fourier_op(x) if self._nufft_dims: # we need to move the nufft-dimensions to the end and flatten all other dimensions diff --git a/tests/operators/models/test_inversion_recovery.py b/tests/operators/models/test_inversion_recovery.py index 71bb4dfe6..52f957f8f 100644 --- a/tests/operators/models/test_inversion_recovery.py +++ b/tests/operators/models/test_inversion_recovery.py @@ -21,7 +21,7 @@ def test_inversion_recovery(ti, result): """ model = InversionRecovery(ti) m0, t1 = create_parameter_tensor_tuples() - (image,) = model.forward(m0, t1) + (image,) = model(m0, t1) # Assert closeness to -m0 for ti=0 if result == '-m0': @@ -37,5 +37,5 @@ def test_inversion_recovery_shape(parameter_shape, contrast_dim_shape, signal_sh (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = InversionRecovery(ti) m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, t1) + (signal,) = model_op(m0, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_molli.py b/tests/operators/models/test_molli.py index 8ee5f9117..c92f92556 100644 --- a/tests/operators/models/test_molli.py +++ b/tests/operators/models/test_molli.py @@ -27,7 +27,7 @@ def test_molli(ti, result): # Generate signal model and torch tensor for comparison model = MOLLI(ti) - (image,) = model.forward(a, c, t1) + (image,) = model(a, c, t1) # Assert closeness to a(1-c) for large ti if result == 'a(1-c)': @@ -43,5 +43,5 @@ def test_molli_shape(parameter_shape, contrast_dim_shape, signal_shape): (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = MOLLI(ti) a, c, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(a, c, t1) + (signal,) = model_op(a, c, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_mono_exponential_decay.py b/tests/operators/models/test_mono_exponential_decay.py index c56e86e12..d77d5862c 100644 --- a/tests/operators/models/test_mono_exponential_decay.py +++ b/tests/operators/models/test_mono_exponential_decay.py @@ -21,7 +21,7 @@ def test_mono_exponential_decay(decay_time, result): """ model = MonoExponentialDecay(decay_time) m0, decay_constant = create_parameter_tensor_tuples() - (image,) = model.forward(m0, decay_constant) + (image,) = model(m0, decay_constant) zeros = torch.zeros_like(m0) @@ -39,5 +39,5 @@ def test_mono_exponential_decay_shape(parameter_shape, contrast_dim_shape, signa (decay_time,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = MonoExponentialDecay(decay_time) m0, decay_constant = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, decay_constant) + (signal,) = model_op(m0, decay_constant) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_saturation_recovery.py b/tests/operators/models/test_saturation_recovery.py index 1a821282b..692d4cc31 100644 --- a/tests/operators/models/test_saturation_recovery.py +++ b/tests/operators/models/test_saturation_recovery.py @@ -21,7 +21,7 @@ def test_saturation_recovery(ti, result): """ model = SaturationRecovery(ti) m0, t1 = create_parameter_tensor_tuples() - (image,) = model.forward(m0, t1) + (image,) = model(m0, t1) zeros = torch.zeros_like(m0) @@ -39,5 +39,5 @@ def test_saturation_recovery_shape(parameter_shape, contrast_dim_shape, signal_s (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = SaturationRecovery(ti) m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, t1) + (signal,) = model_op(m0, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_transient_steady_state_with_preparation.py b/tests/operators/models/test_transient_steady_state_with_preparation.py index 87d21a336..2b9d2e451 100644 --- a/tests/operators/models/test_transient_steady_state_with_preparation.py +++ b/tests/operators/models/test_transient_steady_state_with_preparation.py @@ -23,7 +23,7 @@ def test_transient_steady_state(sampling_time, m0_scaling_preparation, result): repetition_time = 5 model = TransientSteadyStateWithPreparation(sampling_time, repetition_time, m0_scaling_preparation) m0, t1, flip_angle = create_parameter_tensor_tuples(number_of_tensors=3) - (signal,) = model.forward(m0, t1, flip_angle) + (signal,) = model(m0, t1, flip_angle) # Assert closeness to m0 if result == 'm0': @@ -56,7 +56,7 @@ def test_transient_steady_state_inversion_recovery(): analytical_signal = m0 * (1 - 2 * torch.exp(-(sampling_time / t1))) model = TransientSteadyStateWithPreparation(sampling_time, repetition_time=100, m0_scaling_preparation=-1) - (signal,) = model.forward(m0, t1, flip_angle) + (signal,) = model(m0, t1, flip_angle) torch.testing.assert_close(signal, analytical_signal) @@ -77,5 +77,5 @@ def test_transient_steady_state_shape(parameter_shape, contrast_dim_shape, signa sampling_time, repetition_time, m0_scaling_preparation, delay_after_preparation ) m0, t1, flip_angle = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(m0, t1, flip_angle) + (signal,) = model_op(m0, t1, flip_angle) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_wasabi.py b/tests/operators/models/test_wasabi.py index 1aa36384e..6b4494539 100644 --- a/tests/operators/models/test_wasabi.py +++ b/tests/operators/models/test_wasabi.py @@ -14,11 +14,11 @@ def test_WASABI_shift(): """Test symmetry property of shifted and unshifted WASABI spectra.""" offsets_unshifted, b0_shift, rb1, c, d = create_data() wasabi_model = WASABI(offsets=offsets_unshifted) - (signal,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal,) = wasabi_model(b0_shift, rb1, c, d) offsets_shifted, b0_shift, rb1, c, d = create_data(b0_shift=100) wasabi_model = WASABI(offsets=offsets_shifted) - (signal_shifted,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal_shifted,) = wasabi_model(b0_shift, rb1, c, d) lower_index = (offsets_shifted == -300).nonzero()[0][0].item() upper_index = (offsets_shifted == 500).nonzero()[0][0].item() @@ -30,7 +30,7 @@ def test_WASABI_shift(): def test_WASABI_extreme_offset(): offset, b0_shift, rb1, c, d = create_data(offset_max=30000, n_offsets=1) wasabi_model = WASABI(offsets=offset) - (signal,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal,) = wasabi_model(b0_shift, rb1, c, d) assert torch.isclose(signal, torch.tensor([1.0])), 'For an extreme offset, the signal should be unattenuated' @@ -41,5 +41,5 @@ def test_WASABI_shape(parameter_shape, contrast_dim_shape, signal_shape): (offsets,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = WASABI(offsets) b0_shift, rb1, c, d = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=4) - (signal,) = model_op.forward(b0_shift, rb1, c, d) + (signal,) = model_op(b0_shift, rb1, c, d) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_wasabiti.py b/tests/operators/models/test_wasabiti.py index bd451a03b..de2cacc50 100644 --- a/tests/operators/models/test_wasabiti.py +++ b/tests/operators/models/test_wasabiti.py @@ -15,7 +15,7 @@ def test_WASABITI_symmetry(): """Test symmetry property of complete WASABITI spectra.""" offsets, b0_shift, rb1, t1 = create_data() wasabiti_model = WASABITI(offsets=offsets, trec=torch.ones_like(offsets)) - (signal,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal,) = wasabiti_model(b0_shift, rb1, t1) # check that all values are symmetric around the center assert torch.allclose(signal, signal.flipud(), rtol=1e-15), 'Result should be symmetric around center' @@ -26,7 +26,7 @@ def test_WASABITI_symmetry_after_shift(): offsets_shifted, b0_shift, rb1, t1 = create_data(b0_shift=100) trec = torch.ones_like(offsets_shifted) wasabiti_model = WASABITI(offsets=offsets_shifted, trec=trec) - (signal_shifted,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal_shifted,) = wasabiti_model(b0_shift, rb1, t1) lower_index = (offsets_shifted == -300).nonzero()[0][0].item() upper_index = (offsets_shifted == 500).nonzero()[0][0].item() @@ -42,7 +42,7 @@ def test_WASABITI_asymmetry_for_non_unique_trec(): trec[: len(offsets_unshifted) // 2] = 2.0 wasabiti_model = WASABITI(offsets=offsets_unshifted, trec=trec) - (signal,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal,) = wasabiti_model(b0_shift, rb1, t1) assert not torch.allclose(signal, signal.flipud(), rtol=1e-8), 'Result should not be symmetric around center' @@ -53,7 +53,7 @@ def test_WASABITI_relaxation_term(t1): offset, b0_shift, rb1, t1 = create_data(offset_max=50000, n_offsets=1, t1=t1) trec = torch.ones_like(offset) * t1 wasabiti_model = WASABITI(offsets=offset, trec=trec) - sig = wasabiti_model.forward(b0_shift, rb1, t1) + sig = wasabiti_model(b0_shift, rb1, t1) assert torch.isclose(sig[0], torch.FloatTensor([1 - torch.exp(torch.FloatTensor([-1]))]), rtol=1e-8) @@ -72,5 +72,5 @@ def test_WASABITI_shape(parameter_shape, contrast_dim_shape, signal_shape): ti, trec = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=2) model_op = WASABITI(ti, trec) b0_shift, rb1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(b0_shift, rb1, t1) + (signal,) = model_op(b0_shift, rb1, t1) assert signal.shape == signal_shape diff --git a/tests/operators/test_sensitivity_op.py b/tests/operators/test_sensitivity_op.py index 25efc5bee..1321498d8 100644 --- a/tests/operators/test_sensitivity_op.py +++ b/tests/operators/test_sensitivity_op.py @@ -90,7 +90,7 @@ def test_sensitivity_op_other_dim_compatibility_fail(n_other_csm, n_other_img): # Apply to n_other_img shape u = random_generator.complex64_tensor(size=(n_other_img, 1, *n_zyx)) with pytest.raises(RuntimeError, match='The size of tensor'): - sensitivity_op.forward(u) + sensitivity_op(u) v = random_generator.complex64_tensor(size=(n_other_img, n_coils, *n_zyx)) with pytest.raises(RuntimeError, match='The size of tensor'): diff --git a/tests/operators/test_zero_pad_op.py b/tests/operators/test_zero_pad_op.py index a56635f5f..ce2d8855d 100644 --- a/tests/operators/test_zero_pad_op.py +++ b/tests/operators/test_zero_pad_op.py @@ -20,7 +20,7 @@ def test_zero_pad_op_content(): original_shape=tuple([original_shape[d] for d in padding_dimensions]), padded_shape=tuple([padded_shape[d] for d in padding_dimensions]), ) - (padded_data,) = zero_padding_op.forward(original_data) + (padded_data,) = zero_padding_op(original_data) # Compare overlapping region torch.testing.assert_close(original_data[:, 10:90, :, 50:150, :, :], padded_data[:, :, :, :, 95:145, :]) From eecdf490f6bb11c147de8e680f8581e23156fd47 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch <christoph.kolbitsch@ptb.de> Date: Thu, 24 Oct 2024 13:13:32 +0200 Subject: [PATCH 04/23] Remove superfluous .forward (#475) Co-authored-by: Felix Zimmermann <fzimmermann89@gmail.com> --- src/mrpro/operators/FastFourierOp.py | 2 +- tests/operators/test_fast_fourier_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/FastFourierOp.py b/src/mrpro/operators/FastFourierOp.py index 53f3f6eb4..4ffe7e3a6 100644 --- a/src/mrpro/operators/FastFourierOp.py +++ b/src/mrpro/operators/FastFourierOp.py @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: FFT of x """ y = torch.fft.fftshift( - torch.fft.fftn(torch.fft.ifftshift(*self._pad_op.forward(x), dim=self._dim), dim=self._dim, norm='ortho'), + torch.fft.fftn(torch.fft.ifftshift(*self._pad_op(x), dim=self._dim), dim=self._dim, norm='ortho'), dim=self._dim, ) return (y,) diff --git a/tests/operators/test_fast_fourier_op.py b/tests/operators/test_fast_fourier_op.py index 7ac1a8211..7e2f47fd7 100644 --- a/tests/operators/test_fast_fourier_op.py +++ b/tests/operators/test_fast_fourier_op.py @@ -28,7 +28,7 @@ def test_fast_fourier_op_forward(npoints, a): # Transform image to k-space ff_op = FastFourierOp(dim=(0,)) - (igauss_fwd,) = ff_op.forward(igauss) + (igauss_fwd,) = ff_op(igauss) # Scaling to "undo" fft scaling igauss_fwd *= np.sqrt(npoints) / 2 From bbdc97434828780d6ae7016d27353361e4d35ad1 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch <christoph.kolbitsch@ptb.de> Date: Sat, 26 Oct 2024 19:53:53 +0200 Subject: [PATCH 05/23] Add RegularizedIterativeSENSEReconstruction (#388) Co-authored-by: NAME <Pierrick.bouilloux2002@gmail.com> --- ...rized_iterative_sense_reconstruction.ipynb | 389 ++++++++++++++++++ ...ularized_iterative_sense_reconstruction.py | 193 +++++++++ .../IterativeSENSEReconstruction.py | 87 +--- .../reconstruction/Reconstruction.py | 4 +- ...RegularizedIterativeSENSEReconstruction.py | 139 +++++++ .../algorithms/reconstruction/__init__.py | 6 +- 6 files changed, 733 insertions(+), 85 deletions(-) create mode 100644 examples/regularized_iterative_sense_reconstruction.ipynb create mode 100644 examples/regularized_iterative_sense_reconstruction.py create mode 100644 src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb new file mode 100644 index 000000000..6b1c2704b --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "af432293", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Regularized Iterative SENSE Reconstruction of 2D golden angle radial data\n", + "Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a7a6ce3", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# define zenodo URL of the example ismrmd data\n", + "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cd8486b", + "metadata": {}, + "outputs": [], + "source": [ + "# Download raw data\n", + "import tempfile\n", + "\n", + "import requests\n", + "\n", + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" + ] + }, + { + "cell_type": "markdown", + "id": "6a9defa1", + "metadata": {}, + "source": [ + "### Image reconstruction\n", + "We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data.\n", + "RegularizedIterativeSENSEReconstruction solves the following reconstruction problem:\n", + "\n", + "Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms,\n", + "coil sensitivity maps...) $A$ then we can formulate the forward problem as:\n", + "\n", + "$ y = Ax + n $\n", + "\n", + "where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 $\n", + "\n", + "where $W^\\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal\n", + "operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain\n", + "a solution with certain properties:\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$\n", + "\n", + "where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image.\n", + "With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$.\n", + "\n", + "Setting the derivative of the functional $F$ to zero and rearranging yields\n", + "\n", + "$ (A^H W A + l B) x = A^H W y + l x_{reg}$\n", + "\n", + "which is a linear system $Hx = b$ that needs to be solved for $x$.\n", + "\n", + "One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution\n", + "dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks\n", + "has been used as an image regulariser.\n", + "\n", + "In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image.\n", + "Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using\n", + "only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the\n", + "regularization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4da15c2", + "metadata": {}, + "outputs": [], + "source": [ + "import mrpro" + ] + }, + { + "cell_type": "markdown", + "id": "de055070", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Read-in the raw data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ac1d89f", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.data import KData\n", + "from mrpro.data.traj_calculators import KTrajectoryIsmrmrd\n", + "\n", + "# Load in the Data and the trajectory from the ISMRMRD file\n", + "kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd())\n", + "kdata.header.recon_matrix.x = 256\n", + "kdata.header.recon_matrix.y = 256" + ] + }, + { + "cell_type": "markdown", + "id": "1f389140", + "metadata": {}, + "source": [ + "##### Image $x_{reg}$ from fully sampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "212b915c", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction\n", + "from mrpro.data import CsmData\n", + "\n", + "# Estimate coil maps\n", + "direct_reconstruction = DirectReconstruction(kdata, csm=None)\n", + "img_coilwise = direct_reconstruction(kdata)\n", + "csm = CsmData.from_idata_walsh(img_coilwise)\n", + "\n", + "# Iterative SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3)\n", + "img_iterative_sense = iterative_sense_reconstruction(kdata)" + ] + }, + { + "cell_type": "markdown", + "id": "bec6b712", + "metadata": {}, + "source": [ + "##### Image $x$ from undersampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6740447", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Data undersampling, i.e. take only the first 20 radial lines\n", + "idx_us = torch.arange(0, 20)[None, :]\n", + "kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fbfd664", + "metadata": {}, + "outputs": [], + "source": [ + "# Iterativ SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6)\n", + "img_us_iterative_sense = iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "041ffe72", + "metadata": {}, + "outputs": [], + "source": [ + "# Regularized iterativ SENSE reconstruction\n", + "from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction\n", + "\n", + "regularization_weight = 1.0\n", + "n_iterations = 6\n", + "regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction(\n", + " kdata_us,\n", + " csm=csm,\n", + " n_iterations=n_iterations,\n", + " regularization_data=img_iterative_sense.data,\n", + " regularization_weight=regularization_weight,\n", + ")\n", + "img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d5bbec1", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()]\n", + "vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4))\n", + "for ind in range(3):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "2bd49c87", + "metadata": {}, + "source": [ + "### Behind the scenes" + ] + }, + { + "cell_type": "markdown", + "id": "53779251", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Set-up the density compensation operator $W$ and acquisition model $A$\n", + "\n", + "This is very similar to the iterative SENSE reconstruction. For more detail please look at the\n", + "iterative_sense_reconstruction notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e985a4f3", + "metadata": {}, + "outputs": [], + "source": [ + "dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator()\n", + "fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us)\n", + "csm_operator = csm.as_operator()\n", + "acquisition_operator = fourier_operator @ csm_operator" + ] + }, + { + "cell_type": "markdown", + "id": "2daa0fee", + "metadata": {}, + "source": [ + "##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac1d5fb4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "right_hand_side = (\n", + " acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "76a0b153", + "metadata": {}, + "source": [ + "##### Set-up the linear self-adjoint operator $H = A^H W A + l$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5effb592", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.operators import IdentityOp\n", + "\n", + "operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor(\n", + " regularization_weight\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f24a8588", + "metadata": {}, + "source": [ + "##### Run conjugate gradient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96827838", + "metadata": {}, + "outputs": [], + "source": [ + "img_manual = mrpro.algorithms.optimizers.cg(\n", + " operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18c065a7", + "metadata": {}, + "outputs": [], + "source": [ + "# Display the reconstructed image\n", + "vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]]\n", + "vis_title = ['Regularized Iterative SENSE R=20', '\"Manual\" Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4))\n", + "for ind in range(2):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "d6d7efdf", + "metadata": {}, + "source": [ + "### Check for equal results\n", + "The two versions should result in the same image data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f59b6015", + "metadata": {}, + "outputs": [], + "source": [ + "# If the assert statement did not raise an exception, the results are equal.\n", + "assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual)" + ] + }, + { + "cell_type": "markdown", + "id": "6ecd6e70", + "metadata": {}, + "source": [ + "### Next steps\n", + "Play around with the regularization_weight to see how it effects the final image quality.\n", + "\n", + "Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications\n", + "we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the\n", + "streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality\n", + "compared to the unregularised images." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py new file mode 100644 index 000000000..e41dc4ac5 --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -0,0 +1,193 @@ +# %% [markdown] +# # Regularized Iterative SENSE Reconstruction of 2D golden angle radial data +# Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data +# %% +# define zenodo URL of the example ismrmd data +zenodo_url = 'https://zenodo.org/records/10854057/files/' +fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' +# %% +# Download raw data +import tempfile + +import requests + +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() + +# %% [markdown] +# ### Image reconstruction +# We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data. +# RegularizedIterativeSENSEReconstruction solves the following reconstruction problem: +# +# Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms, +# coil sensitivity maps...) $A$ then we can formulate the forward problem as: +# +# $ y = Ax + n $ +# +# where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$ +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 $ +# +# where $W^\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal +# operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain +# a solution with certain properties: +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$ +# +# where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image. +# With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$. +# +# Setting the derivative of the functional $F$ to zero and rearranging yields +# +# $ (A^H W A + l B) x = A^H W y + l x_{reg}$ +# +# which is a linear system $Hx = b$ that needs to be solved for $x$. +# +# One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution +# dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks +# has been used as an image regulariser. +# +# In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image. +# Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using +# only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the +# regularization. + +# %% +import mrpro + +# %% [markdown] +# ##### Read-in the raw data +# %% +from mrpro.data import KData +from mrpro.data.traj_calculators import KTrajectoryIsmrmrd + +# Load in the Data and the trajectory from the ISMRMRD file +kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd()) +kdata.header.recon_matrix.x = 256 +kdata.header.recon_matrix.y = 256 + +# %% [markdown] +# ##### Image $x_{reg}$ from fully sampled data + +# %% +from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction +from mrpro.data import CsmData + +# Estimate coil maps +direct_reconstruction = DirectReconstruction(kdata, csm=None) +img_coilwise = direct_reconstruction(kdata) +csm = CsmData.from_idata_walsh(img_coilwise) + +# Iterative SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3) +img_iterative_sense = iterative_sense_reconstruction(kdata) + +# %% [markdown] +# ##### Image $x$ from undersampled data + +# %% +import torch + +# Data undersampling, i.e. take only the first 20 radial lines +idx_us = torch.arange(0, 20)[None, :] +kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition') + +# %% +# Iterativ SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6) +img_us_iterative_sense = iterative_sense_reconstruction(kdata_us) + +# %% +# Regularized iterativ SENSE reconstruction +from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction + +regularization_weight = 1.0 +n_iterations = 6 +regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction( + kdata_us, + csm=csm, + n_iterations=n_iterations, + regularization_data=img_iterative_sense.data, + regularization_weight=regularization_weight, +) +img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us) + +# %% +import matplotlib.pyplot as plt + +vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()] +vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4)) +for ind in range(3): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + + +# %% [markdown] +# ### Behind the scenes + +# %% [markdown] +# ##### Set-up the density compensation operator $W$ and acquisition model $A$ +# +# This is very similar to the iterative SENSE reconstruction. For more detail please look at the +# iterative_sense_reconstruction notebook. +# %% +dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator() +fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us) +csm_operator = csm.as_operator() +acquisition_operator = fourier_operator @ csm_operator + +# %% [markdown] +# ##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$ + +# %% +right_hand_side = ( + acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data +) + + +# %% [markdown] +# ##### Set-up the linear self-adjoint operator $H = A^H W A + l$ + +# %% +from mrpro.operators import IdentityOp + +operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor( + regularization_weight +) + +# %% [markdown] +# ##### Run conjugate gradient + +# %% +img_manual = mrpro.algorithms.optimizers.cg( + operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0 +) + +# %% +# Display the reconstructed image +vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]] +vis_title = ['Regularized Iterative SENSE R=20', '"Manual" Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4)) +for ind in range(2): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + +# %% [markdown] +# ### Check for equal results +# The two versions should result in the same image data. + +# %% +# If the assert statement did not raise an exception, the results are equal. +assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual) + +# %% [markdown] +# ### Next steps +# Play around with the regularization_weight to see how it effects the final image quality. +# +# Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications +# we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the +# streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality +# compared to the unregularised images. diff --git a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py index 04d37128c..32785a91a 100644 --- a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py +++ b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py @@ -4,20 +4,17 @@ from collections.abc import Callable -import torch - -from mrpro.algorithms.optimizers.cg import cg -from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import ( + RegularizedIterativeSENSEReconstruction, +) from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData -from mrpro.data.IData import IData from mrpro.data.KNoise import KNoise from mrpro.operators.LinearOperator import LinearOperator -class IterativeSENSEReconstruction(DirectReconstruction): +class IterativeSENSEReconstruction(RegularizedIterativeSENSEReconstruction): r"""Iterative SENSE reconstruction. This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2` @@ -49,6 +46,8 @@ def __init__( ) -> None: """Initialize IterativeSENSEReconstruction. + For a regularized version of the iterative SENSE algorithm please see RegularizedIterativeSENSEReconstruction. + Parameters ---------- kdata @@ -74,76 +73,4 @@ def __init__( ValueError If the kdata and fourier_op are None or if csm is a Callable but kdata is None. """ - super().__init__(kdata, fourier_op, csm, noise, dcf) - self.n_iterations = n_iterations - - def _self_adjoint_operator(self) -> LinearOperator: - """Create the self-adjoint operator. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Create the self-adjoint operator as :math:`H = A^H W A` if the DCF is not None otherwise as :math:`H = A^H A`. - """ - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Create H = A^H W A - operator = operator.H @ dcf_operator @ operator - else: - # Create H = A^H A - operator = operator.H @ operator - - return operator - - def _right_hand_side(self, kdata: KData) -> torch.Tensor: - """Calculate the right-hand-side of the normal equation. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Calculate the right-hand-side as :math:`b = A^H W y` if the DCF is not None otherwise as :math:`b = A^H y`. - - Parameters - ---------- - kdata - k-space data to reconstruct. - """ - device = kdata.data.device - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Calculate b = A^H W y - (right_hand_side,) = operator.to(device).H(dcf_operator(kdata.data)[0]) - else: - # Calculate b = A^H y - (right_hand_side,) = operator.to(device).H(kdata.data) - - return right_hand_side - - def forward(self, kdata: KData) -> IData: - """Apply the reconstruction. - - Parameters - ---------- - kdata - k-space data to reconstruct. - - Returns - ------- - the reconstruced image. - """ - device = kdata.data.device - if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) - - operator = self._self_adjoint_operator().to(device) - right_hand_side = self._right_hand_side(kdata) - - img_tensor = cg( - operator, right_hand_side, initial_value=right_hand_side, max_iterations=self.n_iterations, tolerance=0.0 - ) - img = IData.from_tensor_and_kheader(img_tensor, kdata.header) - return img + super().__init__(kdata, fourier_op, csm, noise, dcf, n_iterations=n_iterations, regularization_weight=0) diff --git a/src/mrpro/algorithms/reconstruction/Reconstruction.py b/src/mrpro/algorithms/reconstruction/Reconstruction.py index 127be8d5f..c4208157e 100644 --- a/src/mrpro/algorithms/reconstruction/Reconstruction.py +++ b/src/mrpro/algorithms/reconstruction/Reconstruction.py @@ -101,15 +101,13 @@ def direct_reconstruction(self, kdata: KData) -> IData: ------- image data """ - device = kdata.data.device if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) + kdata = prewhiten_kspace(kdata, self.noise) operator = self.fourier_op if self.csm is not None: operator = operator @ self.csm.as_operator() if self.dcf is not None: operator = self.dcf.as_operator() @ operator - operator = operator.to(device) (img_tensor,) = operator.H(kdata.data) img = IData.from_tensor_and_kheader(img_tensor, kdata.header) return img diff --git a/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py new file mode 100644 index 000000000..c9a307ebe --- /dev/null +++ b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py @@ -0,0 +1,139 @@ +"""Regularized Iterative SENSE Reconstruction by adjoint Fourier transform.""" + +from __future__ import annotations + +from collections.abc import Callable + +import torch + +from mrpro.algorithms.optimizers.cg import cg +from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace +from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.data._kdata.KData import KData +from mrpro.data.CsmData import CsmData +from mrpro.data.DcfData import DcfData +from mrpro.data.IData import IData +from mrpro.data.KNoise import KNoise +from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperator import LinearOperator + + +class RegularizedIterativeSENSEReconstruction(DirectReconstruction): + r"""Regularized iterative SENSE reconstruction. + + This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2 + + \frac{1}{2}L||Bx - x_0||_2^2` + by using a conjugate gradient algorithm to solve + :math:`H x = b` with :math:`H = A^H W A + L B` and :math:`b = A^H W y + L x_0` where :math:`A` + is the acquisition model (coil sensitivity maps, Fourier operator, k-space sampling), :math:`y` is the acquired + k-space data, :math:`W` describes the density compensation, :math:`L` is the strength of the regularization and + :math:`x_0` is the regularization image (i.e. the prior). :math:`B` is a linear operator applied to :math:`x`. + """ + + n_iterations: int + """Number of CG iterations.""" + + regularization_data: torch.Tensor + """Regularization data (i.e. prior) :math:`x_0`.""" + + regularization_weight: torch.Tensor + """Strength of the regularization :math:`L`.""" + + regularization_op: LinearOperator + """Linear operator :math:`B` applied to the current estimate in the regularization term.""" + + def __init__( + self, + kdata: KData | None = None, + fourier_op: LinearOperator | None = None, + csm: Callable | CsmData | None = CsmData.from_idata_walsh, + noise: KNoise | None = None, + dcf: DcfData | None = None, + *, + n_iterations: int = 5, + regularization_data: float | torch.Tensor = 0.0, + regularization_weight: float | torch.Tensor, + regularization_op: LinearOperator | None = None, + ) -> None: + """Initialize RegularizedIterativeSENSEReconstruction. + + For a unregularized version of the iterative SENSE algorithm the regularization_weight can be set to 0 or + IterativeSENSEReconstruction algorithm can be used. + + Parameters + ---------- + kdata + KData. If kdata is provided and fourier_op or dcf are None, then fourier_op and dcf are estimated based on + kdata. Otherwise fourier_op and dcf are used as provided. + fourier_op + Instance of the FourierOperator used for reconstruction. If None, set up based on kdata. + csm + Sensitivity maps for coil combination. If None, no coil combination is carried out, i.e. images for each + coil are returned. If a callable is provided, coil images are reconstructed using the adjoint of the + FourierOperator (including density compensation) and then sensitivity maps are calculated using the + callable. For this, kdata needs also to be provided. For examples have a look at the CsmData class + e.g. from_idata_walsh or from_idata_inati. + noise + KNoise used for prewhitening. If None, no prewhitening is performed + dcf + K-space sampling density compensation. If None, set up based on kdata. + n_iterations + Number of CG iterations + regularization_data + Regularization data, e.g. a reference image (:math:`x_0`). + regularization_weight + Strength of the regularization (:math:`L`). + regularization_op + Linear operator :math:`B` applied to the current estimate in the regularization term. If None, nothing is + applied to the current estimate. + + + Raises + ------ + ValueError + If the kdata and fourier_op are None or if csm is a Callable but kdata is None. + """ + super().__init__(kdata, fourier_op, csm, noise, dcf) + self.n_iterations = n_iterations + self.regularization_data = torch.as_tensor(regularization_data) + self.regularization_weight = torch.as_tensor(regularization_weight) + self.regularization_op = regularization_op if regularization_op is not None else IdentityOp() + + def forward(self, kdata: KData) -> IData: + """Apply the reconstruction. + + Parameters + ---------- + kdata + k-space data to reconstruct. + + Returns + ------- + the reconstruced image. + """ + if self.noise is not None: + kdata = prewhiten_kspace(kdata, self.noise) + + # Create the normal operator as H = A^H W A if the DCF is not None otherwise as H = A^H A. + # The acquisition model is A = F S if the CSM S is defined otherwise A = F with the Fourier operator F + csm_op = self.csm.as_operator() if self.csm is not None else IdentityOp() + precondition_op = self.dcf.as_operator() if self.dcf is not None else IdentityOp() + operator = (self.fourier_op @ csm_op).H @ precondition_op @ (self.fourier_op @ csm_op) + + # Calculate the right-hand-side as b = A^H W y if the DCF is not None otherwise as b = A^H y. + (right_hand_side,) = (self.fourier_op @ csm_op).H(precondition_op(kdata.data)[0]) + + # Add regularization + if not torch.all(self.regularization_weight == 0): + operator = operator + IdentityOp() @ (self.regularization_weight * self.regularization_op) + right_hand_side += self.regularization_weight * self.regularization_data + + img_tensor = cg( + operator, + right_hand_side, + initial_value=right_hand_side, + max_iterations=self.n_iterations, + tolerance=0.0, + ) + img = IData.from_tensor_and_kheader(img_tensor, kdata.header) + return img diff --git a/src/mrpro/algorithms/reconstruction/__init__.py b/src/mrpro/algorithms/reconstruction/__init__.py index 180186dd4..38b539f8b 100644 --- a/src/mrpro/algorithms/reconstruction/__init__.py +++ b/src/mrpro/algorithms/reconstruction/__init__.py @@ -1,8 +1,10 @@ from mrpro.algorithms.reconstruction.Reconstruction import Reconstruction from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import RegularizedIterativeSENSEReconstruction from mrpro.algorithms.reconstruction.IterativeSENSEReconstruction import IterativeSENSEReconstruction __all__ = [ "DirectReconstruction", "IterativeSENSEReconstruction", - "Reconstruction" -] \ No newline at end of file + "Reconstruction", + "RegularizedIterativeSENSEReconstruction" +] From 87cdaf1fcda5c537405c4fd73a363e0bde706a7f Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Tue, 29 Oct 2024 10:56:32 +0100 Subject: [PATCH 06/23] Release 0.241029 (#491) --- src/mrpro/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 0a67c464f..0f6ae6fb6 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241015 +0.241029 From 8b73303fb31779e10909d0ccccfa9fd64122fb36 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Wed, 30 Oct 2024 11:33:11 +0100 Subject: [PATCH 07/23] Fix unrolled CG autograd (#492) --- pyproject.toml | 1 + src/mrpro/algorithms/optimizers/cg.py | 11 +++-------- tests/algorithms/test_cg.py | 10 ++++++++++ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d2780589..31798a35c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ testpaths = ["tests"] filterwarnings = [ "error", "ignore:'write_like_original':DeprecationWarning:pydicom:", + "ignore:Anomaly Detection has been enabled:UserWarning", #torch.autograd ] addopts = "-n auto" markers = ["cuda : Tests only to be run when cuda device is available"] diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index 54cb1d782..b879796b0 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -81,11 +81,6 @@ def cg( if torch.vdot(residual.flatten(), residual.flatten()) == 0: return solution - # squared tolerance; - # (we will check ||residual||^2 < tolerance^2 instead of ||residual|| < tol - # to avoid the computation of the root for the norm) - tolerance_squared = tolerance**2 - # dummy value. new value will be set in loop before first usage residual_norm_squared_previous = None @@ -95,7 +90,7 @@ def cg( residual_norm_squared = torch.vdot(residual_flat, residual_flat).real # check if the solution is already accurate enough - if tolerance != 0 and (residual_norm_squared < tolerance_squared): + if tolerance != 0 and (residual_norm_squared < tolerance**2): return solution if iteration > 0: @@ -105,8 +100,8 @@ def cg( # update estimates of the solution and the residual (operator_conjugate_vector,) = operator(conjugate_vector) alpha = residual_norm_squared / (torch.vdot(conjugate_vector.flatten(), operator_conjugate_vector.flatten())) - solution += alpha * conjugate_vector - residual -= alpha * operator_conjugate_vector + solution = solution + alpha * conjugate_vector + residual = residual - alpha * operator_conjugate_vector residual_norm_squared_previous = residual_norm_squared diff --git a/tests/algorithms/test_cg.py b/tests/algorithms/test_cg.py index 16c16e78b..8a4434e2a 100644 --- a/tests/algorithms/test_cg.py +++ b/tests/algorithms/test_cg.py @@ -145,3 +145,13 @@ def callback(cg_status: CGStatus) -> None: assert True cg(h_operator, right_hand_side, callback=callback) + + +def test_autograd(system): + """Test autograd through cg""" + h_operator, right_hand_side, _ = system + right_hand_side.requires_grad_(True) + with torch.autograd.detect_anomaly(): + result = cg(h_operator, right_hand_side, tolerance=0, max_iterations=5) + result.abs().sum().backward() + assert right_hand_side.grad is not None From eeb923354ae00a1ee728fc20be8d8495afa0f468 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Sat, 2 Nov 2024 19:29:46 +0100 Subject: [PATCH 08/23] Add matrix of LinearOperators (#432) --- src/mrpro/operators/LinearOperator.py | 16 +- src/mrpro/operators/LinearOperatorMatrix.py | 363 +++++++++++++++++++ src/mrpro/operators/__init__.py | 2 + tests/operators/test_linearoperatormatrix.py | 322 ++++++++++++++++ tests/operators/test_operator_norm.py | 5 +- 5 files changed, 704 insertions(+), 4 deletions(-) create mode 100644 src/mrpro/operators/LinearOperatorMatrix.py create mode 100644 tests/operators/test_linearoperatormatrix.py diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index f136d9d51..d919e63c6 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -102,7 +102,7 @@ def operator_norm( max_iterations: int = 20, relative_tolerance: float = 1e-4, absolute_tolerance: float = 1e-5, - callback: Callable | None = None, + callback: Callable[[torch.Tensor], None] | None = None, ) -> torch.Tensor: """Power iteration for computing the operator norm of the linear operator. @@ -294,6 +294,20 @@ def __rmul__(self, other: torch.Tensor | complex) -> LinearOperator: else: return NotImplemented # type: ignore[unreachable] + def __and__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Vertical stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self], [other]] + return mrpro.operators.LinearOperatorMatrix(operators) + + def __or__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Horizontal stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self, other]] + return mrpro.operators.LinearOperatorMatrix(operators) + @property def gram(self) -> LinearOperator: """Gram operator. diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py new file mode 100644 index 000000000..ab0673398 --- /dev/null +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -0,0 +1,363 @@ +"""Linear Operator Matrix class.""" + +from __future__ import annotations + +import operator +from collections.abc import Callable, Iterator, Sequence +from functools import reduce +from types import EllipsisType +from typing import cast + +import torch +from typing_extensions import Self, Unpack + +from mrpro.operators.LinearOperator import LinearOperator, LinearOperatorSum +from mrpro.operators.Operator import Operator +from mrpro.operators.ZeroOp import ZeroOp + +_SingleIdxType = int | slice | EllipsisType | Sequence[int] +_IdxType = _SingleIdxType | tuple[_SingleIdxType, _SingleIdxType] + + +class LinearOperatorMatrix(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]]): + r"""Matrix of Linear Operators. + + A matrix of Linear Operators, where each element is a Linear Operator. + + This matrix can be applied to a sequence of tensors, where the number of tensors should match + the number of columns of the matrix. The output will be a sequence of tensors, where the number + of tensors will match the number of rows of the matrix. + The i-th output tensor is calculated as + :math:`\sum_j \text{operators}[i][j](x[j])` where :math:`\text{operators}[i][j]` is the linear operator + in the i-th row and j-th column and :math:`x[j]` is the j-th input tensor. + + The matrix can be indexed and sliced like a regular matrix to get submatrices. + If indexing returns a single element, it is returned as a Linear Operator. + + Basic arithmetic operations are supported with Linear Operators and Tensors. + + """ + + _operators: list[list[LinearOperator]] + + def __init__(self, operators: Sequence[Sequence[LinearOperator]]): + """Initialize Linear Operator Matrix. + + Create a matrix of LinearOperators from a sequence of rows, where each row is a sequence + of LinearOperators that represent the columns of the matrix. + + Parameters + ---------- + operators + A sequence of rows, which are sequences of Linear Operators. + """ + if not all(isinstance(op, LinearOperator) for row in operators for op in row): + raise ValueError('All elements should be LinearOperators.') + if not all(len(row) == len(operators[0]) for row in operators): + raise ValueError('All rows should have the same length.') + super().__init__() + self._operators = cast( # cast because ModuleList is not recognized as a list + list[list[LinearOperator]], torch.nn.ModuleList(torch.nn.ModuleList(row) for row in operators) + ) + self._shape = (len(operators), len(operators[0]) if operators else 0) + + @property + def shape(self) -> tuple[int, int]: + """Shape of the Operator Matrix (rows, columns).""" + return self._shape + + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has columns. + + Returns + ------- + Output tensors. The same number of tensors as the operator has rows. + """ + if len(x) != self.shape[1]: + raise ValueError('Input should be the same number of tensors as the LinearOperatorMatrix has columns.') + return tuple( + reduce(operator.add, (op(xi)[0] for op, xi in zip(row, x, strict=True))) for row in self._operators + ) + + def __getitem__(self, idx: _IdxType) -> Self | LinearOperator: + """Index the Operator Matrix. + + Parameters + ---------- + idx + Index or slice to select rows and columns. + + Returns + ------- + Subset LinearOperatorMatrix or Linear Operator. + """ + idxs: tuple[_SingleIdxType, _SingleIdxType] = idx if isinstance(idx, tuple) else (idx, slice(None)) + if len(idxs) > 2: + raise IndexError('Too many indices for LinearOperatorMatrix') + + def _to_numeric_index(idx: slice | int | Sequence[int] | EllipsisType, length: int) -> Sequence[int]: + """Convert index to a sequence of integers or raise an error.""" + if isinstance(idx, slice): + if (idx.start is not None and (idx.start < -length or idx.start >= length)) or ( + idx.stop is not None and (idx.stop < -length or idx.stop > length) + ): + raise IndexError('Index out of range') + return range(*idx.indices(length)) + if isinstance(idx, int): + if idx < -length or idx >= length: + raise IndexError('Index out of range') + return (idx,) + if idx is Ellipsis: + return range(length) + if isinstance(idx, Sequence): + if min(idx) < -length or max(idx) >= length: + raise IndexError('Index out of range') + return idx + else: + raise IndexError('Invalid index type') + + row_numbers = _to_numeric_index(idxs[0], self._shape[0]) + col_numbers = _to_numeric_index(idxs[1], self._shape[1]) + + sliced_operators = [ + [row[col_number] for col_number in col_numbers] + for row in [self._operators[row_number] for row_number in row_numbers] + ] + + # Return a single operator if only one row and column is selected + if len(row_numbers) == 1 and len(col_numbers) == 1: + return sliced_operators[0][0] + else: + return self.__class__(sliced_operators) + + def __iter__(self) -> Iterator[Sequence[LinearOperator]]: + """Iterate over the rows of the Operator Matrix.""" + return iter(self._operators) + + def __repr__(self): + """Representation of the Operator Matrix.""" + return f'LinearOperatorMatrix(shape={self._shape}, operators={self._operators})' + + # Note: The type ignores are needed because we currently cannot do arithmetic operations with non-linear operators. + def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: ignore[override] + """Addition.""" + operators: list[list[LinearOperator]] = [] + if isinstance(other, LinearOperatorMatrix): + if self.shape != other.shape: + raise ValueError('OperatorMatrix shapes do not match.') + for self_row, other_row in zip(self._operators, other._operators, strict=False): + operators.append([s + o for s, o in zip(self_row, other_row, strict=False)]) + elif isinstance(other, LinearOperator | torch.Tensor): + if not self.shape[0] == self.shape[1]: + raise NotImplementedError('Cannot add a LinearOperator to a non-square OperatorMatrix.') + for i, self_row in enumerate(self._operators): + operators.append([op + other if i == j else op for j, op in enumerate(self_row)]) + else: + return NotImplemented # type: ignore[unreachable] + return self.__class__(operators) + + def __radd__(self, other: Self | LinearOperator | torch.Tensor) -> Self: + """Right addition.""" + return self.__add__(other) + + def __mul__(self, other: torch.Tensor | Sequence[torch.Tensor | complex] | complex) -> Self: + """LinearOperatorMatrix*Tensor multiplication. + + Example: ([A,B]*c)(x) = [A*c, B*c](x) = A(c*x) + B(c*x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[1] + elif len(other) != self.shape[1]: + raise ValueError('Other should have the same length as the operator has columns.') + else: + other_ = other + operators = [] + for row in self._operators: + operators.append([op * o for op, o in zip(row, other_, strict=True)]) + return self.__class__(operators) + + def __rmul__(self, other: torch.Tensor | Sequence[torch.Tensor] | complex) -> Self: + """Tensor*LinearOperatorMatrix multiplication. + + Example: (c*[A,B])(x) = [c*A, c*B](x) = c*A(x) + c*B(x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[0] + elif len(other) != self.shape[0]: + raise ValueError('Other should have the same length as the operator has rows.') + else: + other_ = other + operators = [] + for row, o in zip(self._operators, other_, strict=True): + operators.append([cast(LinearOperator, o * op) for op in row]) + return self.__class__(operators) + + def __matmul__(self, other: LinearOperator | Self) -> Self: # type: ignore[override] + """Composition of operators.""" + if isinstance(other, LinearOperator): + return self._binary_operation(other, '__matmul__') + elif isinstance(other, LinearOperatorMatrix): + if self.shape[1] != other.shape[0]: + raise ValueError('OperatorMatrix shapes do not match.') + new_operators = [] + for row in self._operators: + new_row = [] + for other_col in zip(*other._operators, strict=True): + elements = [s @ o for s, o in zip(row, other_col, strict=True)] + new_row.append(LinearOperatorSum(*elements)) + new_operators.append(new_row) + return self.__class__(new_operators) + return NotImplemented # type: ignore[unreachable] + + @property + def H(self) -> Self: # noqa N802 + """Adjoints of the operators.""" + return self.__class__([[op.H for op in row] for row in zip(*self._operators, strict=True)]) + + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the adjoint of the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has rows. + + Returns + ------- + Output tensors. The same number of tensors as the operator has columns. + """ + return self.H(*x) + + @classmethod + def from_diagonal(cls, *operators: LinearOperator): + """Create a diagonal LinearOperatorMatrix. + + Create a square LinearOperatorMatrix with the given Linear Operators on the diagonal, + resulting in a block-diagonal linear operator. + + Parameters + ---------- + operators + Sequence of Linear Operators to be placed on the diagonal. + """ + operator_matrix: list[list[LinearOperator]] = [ + [op if i == j else ZeroOp(False) for j in range(len(operators))] for i, op in enumerate(operators) + ] + return cls(operator_matrix) + + def operator_norm( + self, + *initial_value: torch.Tensor, + dim: Sequence[int] | None = None, + max_iterations: int = 20, + relative_tolerance: float = 1e-4, + absolute_tolerance: float = 1e-5, + callback: Callable[[torch.Tensor], None] | None = None, + ) -> torch.Tensor: + """Upper bound of operator norm of the Matrix. + + Uses the bounds :math:`||[A, B]^T|||<=sqrt(||A||^2 + ||B||^2)` and :math:`||[A, B]|||<=max(||A||,||B||)` + to estimate the operator norm of the matrix. + First, operator_norm is called on each element of the matrix. + Next, the norm is estimated for each column using the first bound. + Finally, the norm of the full matrix of linear operators is calculated using the second bound. + + Parameters + ---------- + initial_value + Initial value(s) for the power iteration, length should match the number of columns + of the operator matrix. + dim + Dimensions to calculate the operator norm over. Other dimensions are assumed to be + batch dimensions. None means all dimensions. + max_iterations + Maximum number of iterations used in the power iteration. + relative_tolerance + Relative tolerance for convergence. + absolute_tolerance + Absolute tolerance for convergence. + callback + Callback function to be called with the current estimate of the operator norm. + + + Returns + ------- + Estimated operator norm upper bound. + """ + + def _singlenorm(op: LinearOperator, initial_value: torch.Tensor): + return op.operator_norm( + initial_value, + dim=dim, + max_iterations=max_iterations, + relative_tolerance=relative_tolerance, + absolute_tolerance=absolute_tolerance, + callback=callback, + ) + + if len(initial_value) != self.shape[1]: + raise ValueError('Initial value should have the same length as the operator has columns.') + norms = torch.tensor( + [[_singlenorm(op, iv) for op, iv in zip(row, initial_value, strict=True)] for row in self._operators] + ) + norm = norms.square().sum(-2).sqrt().amax(-1).unsqueeze(-1) + return norm + + def __or__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Vertical stacking.""" + if isinstance(other, LinearOperator): + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking : cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[*self._operators[0], other]] + return self.__class__(operators) + else: + if (rows_self := self.shape[0]) != (rows_other := other.shape[0]): + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack matrices with {rows_self} and {rows_other}.' + ) + operators = [[*self_row, *other_row] for self_row, other_row in zip(self, other, strict=True)] + return self.__class__(operators) + + def __ror__(self, other: LinearOperator) -> Self: + """Vertical stacking.""" + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[other, *self._operators[0]]] + return self.__class__(operators) + + def __and__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Horizontal stacking.""" + if isinstance(other, LinearOperator): + if cols := self.shape[1] > 1: + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack LinearOperator and matrix with {cols} columns.' + ) + operators = [*self._operators, [other]] + return self.__class__(operators) + else: + if (cols_self := self.shape[1]) != (cols_other := other.shape[1]): + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack matrices with {cols_self} and {cols_other} columns.' + ) + operators = [*self._operators, *other] + return self.__class__(operators) + + def __rand__(self, other: LinearOperator) -> Self: + """Horizontal stacking.""" + if cols := self.shape[1] > 1: + raise ValueError( + f'Shape mismatch in horizontal stacking: cannot stack LinearOperator and matrix with {cols} columns.' + ) + operators = [[other], *self._operators] + return self.__class__(operators) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 345c337fa..b8c16ebfe 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -11,6 +11,7 @@ from mrpro.operators.FourierOp import FourierOp from mrpro.operators.GridSamplingOp import GridSamplingOp from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp from mrpro.operators.PhaseOp import PhaseOp @@ -36,6 +37,7 @@ "GridSamplingOp", "IdentityOp", "LinearOperator", + "LinearOperatorMatrix", "MagnitudeOp", "MultiIdentityOp", "Operator", diff --git a/tests/operators/test_linearoperatormatrix.py b/tests/operators/test_linearoperatormatrix.py new file mode 100644 index 000000000..7ba87d715 --- /dev/null +++ b/tests/operators/test_linearoperatormatrix.py @@ -0,0 +1,322 @@ +from typing import Any + +import pytest +import torch +from mrpro.operators import EinsumOp, LinearOperator, MagnitudeOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +def random_linearop(size, rng): + """Create a random LinearOperator.""" + return EinsumOp(rng.complex64_tensor(size), '... i j, ... j -> ... i') + + +def random_linearoperatormatrix(size, inner_size, rng): + """Create a random LinearOperatorMatrix.""" + operators = [[random_linearop(inner_size, rng) for i in range(size[1])] for j in range(size[0])] + return LinearOperatorMatrix(operators) + + +def test_linearoperatormatrix_shape(): + """Test creation and shape of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert matrix.shape == (5, 3) + + +def test_linearoperatormatrix_add_matrix(): + """Test addition of two LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 + matrix2)(*vector) + expected = tuple(a + b for a, b in zip(matrix1(*vector), matrix2(*vector), strict=False)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_add_tensor_nonsquare(): + """Test failure of addition of tensor to non-square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + with pytest.raises(NotImplementedError, match='square'): + (matrix1 + other) + + +def test_linearoperatormatrix_add_tensor_square(): + """Add tensor to square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 3), (2, 2), rng) + other = rng.complex64_tensor(2) + vector = rng.complex64_tensor((3, 2)) + result = (matrix1 + other)(*vector) + expected = tuple((mv + other * v for mv, v in zip(matrix1(*vector), vector, strict=True))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_rmul(): + """Test post multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + vector = rng.complex64_tensor((3, 10)) + result = (other * matrix1)(*vector) + expected = tuple(other * el for el in matrix1(*vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_mul(): + """Test pre multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(10) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 * other)(*vector) + expected = matrix1(*(other * el for el in vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition(): + """Test composition of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 @ matrix2)(*vector) + expected = matrix1(*(matrix2(*vector))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition_mismatch(): + """Test composition with mismatching shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((4, 3), (3, 10), rng) + vector = rng.complex64_tensor((4, 10)) + with pytest.raises(ValueError, match='shapes do not match'): + (matrix1 @ matrix2)(*vector) + + +def test_linearoperatormatrix_adjoint(): + """Test adjointness of Adjoint.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + dotproduct_adjointness_test(Wrapper(), rng.complex64_tensor((3, 10)), rng.complex64_tensor((5, 3))) + + +def test_linearoperatormatrix_repr(): + """Test repr of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert 'LinearOperatorMatrix(shape=(5, 3)' in repr(matrix) + + +def test_linearoperatormatrix_getitem(): + """Test slicing of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + def check(actual, expected): + assert tuple(tuple(row) for row in actual) == tuple(tuple(row) for row in expected) + + sliced = matrix[1:3, 2] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[2:3] for row in matrix._operators[1:3]]) + + sliced = matrix[0] + assert sliced.shape == (1, 6) + check(sliced._operators, matrix._operators[:1]) + + sliced = matrix[..., 0] + assert sliced.shape == (12, 1) + check(sliced._operators, [row[:1] for row in matrix._operators]) + + sliced = matrix[1:6:2, (3, 4)] + assert sliced.shape == (3, 2) + check(sliced._operators, [[matrix._operators[i][j] for j in (3, 4)] for i in range(1, 6, 2)]) + + sliced = matrix[-2:-4:-1, -1] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[-1:] for row in matrix._operators[-2:-4:-1]]) + + sliced = matrix[5, 5] + assert isinstance(sliced, LinearOperator) + assert sliced == matrix._operators[5][5] + + +def test_linearoperatormatrix_getitem_error(): + """Test error when slicing with wrong indices.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + with pytest.raises(IndexError, match='Too many indices'): + matrix[1, 1, 1] + + with pytest.raises(IndexError, match='out of range'): + matrix[20] + with pytest.raises(IndexError, match='out of range'): + matrix[-20] + with pytest.raises(IndexError, match='out of range'): + matrix[1:100] + with pytest.raises(IndexError, match='out of range'): + matrix[(100, 1)] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., -20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 1:100] + with pytest.raises(IndexError, match='index type'): + matrix[..., 1.0] + + +def test_linearoperatormatrix_norm_rows(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((3, 1), (3, 10), rng) + vector = rng.complex64_tensor((1, 10)) + result = matrix.operator_norm(*vector) + expected = sum(row[0].operator_norm(vector[0], dim=None) ** 2 for row in matrix._operators) ** 0.5 + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_norm_cols(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((1, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = matrix.operator_norm(*vector) + expected = max(op.operator_norm(v, dim=None) for op, v in zip(matrix._operators[0], vector, strict=False)) + torch.testing.assert_close(result, expected) + + +@pytest.mark.parametrize('seed', [0, 1, 2, 3]) +def test_linearoperatormatrix_norm(seed): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(seed) + matrix = random_linearoperatormatrix((4, 2), (3, 10), rng) + vector = rng.complex64_tensor((2, 10)) + result = matrix.operator_norm(*vector) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + real = Wrapper().operator_norm(vector, dim=None) + + assert result >= real + + +def test_linearoperatormatrix_shorthand_vertical(): + """Test shorthand for vertical stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 10), rng) + op2 = random_linearop((4, 10), rng) + x1 = rng.complex64_tensor((10,)) + + matrix1 = op1 & op2 + assert matrix1.shape == (2, 1) + + actual = matrix1(x1) + expected = (*op1(x1), *op2(x1)) + torch.testing.assert_close(actual, expected) + + matrix2 = op2 & (matrix1 & op1) + assert matrix2.shape == (4, 1) + + matrix3 = matrix2 & matrix2 + assert matrix3.shape == (8, 1) + + actual = matrix3(x1) + expected = 2 * (*op2(x1), *matrix1(x1), *op1(x1)) + torch.testing.assert_close(actual, expected) + + +def test_linearoperatormatrix_shorthand_horizontal(): + """Test shorthand for horizontal stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 4), rng) + op2 = random_linearop((3, 2), rng) + x1 = rng.complex64_tensor((4,)) + x2 = rng.complex64_tensor((2,)) + x3 = rng.complex64_tensor((4,)) + x4 = rng.complex64_tensor((2,)) + + matrix1 = op1 | op2 + assert matrix1.shape == (1, 2) + + actual1 = matrix1(x1, x2) + expected1 = (op1(x1)[0] + op2(x2)[0],) + torch.testing.assert_close(actual1, expected1) + + matrix2 = op2 | (matrix1 | op1) + assert matrix2.shape == (1, 4) + + matrix3 = matrix2 | matrix2 + assert matrix3.shape == (1, 8) + + expected3 = (2 * (op2(x2)[0] + (matrix1(x3, x4)[0] + op1(x1)[0])),) + actual3 = matrix3(x2, x3, x4, x1, x2, x3, x4, x1) + torch.testing.assert_close(actual3, expected3) + + +def test_linearoperatormatrix_stacking_error(): + """Test error when stacking matrix operators with different shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 4), (3, 10), rng) + matrix2 = random_linearoperatormatrix((3, 2), (3, 10), rng) + matrix3 = random_linearoperatormatrix((2, 4), (3, 10), rng) + op = random_linearop((3, 10), rng) + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & matrix2 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | matrix3 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | op + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & op + + +def test_linearoperatormatrix_error_nonlinearop(): + """Test error if trying to create a LinearOperatorMatrix with non linear operator.""" + op: Any = [[MagnitudeOp()]] # Any is used to hide this error from mypy + with pytest.raises(ValueError, match='LinearOperator'): + LinearOperatorMatrix(op) + + +def test_linearoperatormatrix_error_inconsistent_shapes(): + """Test error if trying to create a LinearOperatorMatrix with inonsistent row lengths.""" + rng = RandomGenerator(0) + op = random_linearop((3, 4), rng) + with pytest.raises(ValueError, match='same length'): + LinearOperatorMatrix([[op, op], [op]]) + + +def test_linearoperatormatrix_from_diagonal(): + """Test creation of LinearOperatorMatrix from diagonal.""" + rng = RandomGenerator(0) + ops = [random_linearop((2, 4), rng) for _ in range(3)] + matrix = LinearOperatorMatrix.from_diagonal(*ops) + xs = rng.complex64_tensor((3, 4)) + actual = matrix(*xs) + expected = tuple(op(x)[0] for op, x in zip(ops, xs, strict=False)) + torch.testing.assert_close(actual, expected) diff --git a/tests/operators/test_operator_norm.py b/tests/operators/test_operator_norm.py index d5ae39873..c6e13dcf9 100644 --- a/tests/operators/test_operator_norm.py +++ b/tests/operators/test_operator_norm.py @@ -12,9 +12,8 @@ def test_power_iteration_uses_stopping_criterion(): """Test if the power iteration stops if the absolute and relative tolerance are chosen high.""" - # callback function that should not be called because the power iteration - # should stop if the tolerances are set high - def callback(): + def callback(_): + """Callback function that should not be called, because the power iteration should stop.""" pytest.fail('The power iteration did not stop despite high atol and rtol!') random_generator = RandomGenerator(seed=0) From bc830f5c345bfcc123a627e31395c77060854787 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:14:13 +0100 Subject: [PATCH 09/23] [pre-commit] pre-commit autoupdate (#498) Co-authored-by: Patrick Schuenke <patrick.schuenke@gmail.com> --- .pre-commit-config.yaml | 8 +++---- examples/direct_reconstruction.ipynb | 8 +++---- examples/direct_reconstruction.py | 8 +++---- examples/iterative_sense_reconstruction.ipynb | 8 +++---- examples/iterative_sense_reconstruction.py | 8 +++---- examples/pulseq_2d_radial_golden_angle.ipynb | 22 ++++++++++--------- examples/pulseq_2d_radial_golden_angle.py | 17 +++++++------- ...rized_iterative_sense_reconstruction.ipynb | 8 +++---- ...ularized_iterative_sense_reconstruction.py | 8 +++---- 9 files changed, 49 insertions(+), 46 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7803441eb..95d14317a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-docstring-first @@ -15,14 +15,14 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff # linter args: [--fix] - id: ruff-format # formatter - repo: https://github.com/crate-ci/typos - rev: v1.25.0 + rev: v1.27.0 hooks: - id: typos @@ -34,7 +34,7 @@ repos: exclude: ^tests/ - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy pass_filenames: false diff --git a/examples/direct_reconstruction.ipynb b/examples/direct_reconstruction.ipynb index 3b6dc930e..1e4e74c9c 100644 --- a/examples/direct_reconstruction.ipynb +++ b/examples/direct_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { diff --git a/examples/direct_reconstruction.py b/examples/direct_reconstruction.py index 7672aa7e7..5d55812c9 100644 --- a/examples/direct_reconstruction.py +++ b/examples/direct_reconstruction.py @@ -11,10 +11,10 @@ import requests -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/iterative_sense_reconstruction.ipynb b/examples/iterative_sense_reconstruction.ipynb index 87249b2fb..f612d7522 100644 --- a/examples/iterative_sense_reconstruction.ipynb +++ b/examples/iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { diff --git a/examples/iterative_sense_reconstruction.py b/examples/iterative_sense_reconstruction.py index 6d0bc49a5..ba5e6a01a 100644 --- a/examples/iterative_sense_reconstruction.py +++ b/examples/iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index 52e0310bb..bcb4482a1 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -33,13 +33,14 @@ "cell_type": "code", "execution_count": null, "id": "d16f41f1", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "# define zenodo records URL and create a temporary directory and h5-file\n", "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", - "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')" + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" ] }, { @@ -50,9 +51,10 @@ "outputs": [], "source": [ "# Download raw data using requests\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { @@ -125,10 +127,10 @@ "# download the sequence file from zenodo\n", "zenodo_url = 'https://zenodo.org/records/10868061/files/'\n", "seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq'\n", - "seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq')\n", - "response = requests.get(zenodo_url + seq_fname, timeout=30)\n", - "seq_file.write(response.content)\n", - "seq_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file:\n", + " response = requests.get(zenodo_url + seq_fname, timeout=30)\n", + " seq_file.write(response.content)\n", + " seq_file.flush()" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index 3f857c382..f4db5217a 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -19,13 +19,14 @@ # define zenodo records URL and create a temporary directory and h5-file zenodo_url = 'https://zenodo.org/records/10854057/files/' fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') + # %% # Download raw data using requests -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction using KTrajectoryIsmrmrd @@ -62,10 +63,10 @@ # download the sequence file from zenodo zenodo_url = 'https://zenodo.org/records/10868061/files/' seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq' -seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') -response = requests.get(zenodo_url + seq_fname, timeout=30) -seq_file.write(response.content) -seq_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file: + response = requests.get(zenodo_url + seq_fname, timeout=30) + seq_file.write(response.content) + seq_file.flush() # %% # Read raw data and calculate trajectory using KTrajectoryPulseq diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb index 6b1c2704b..0a6743161 100644 --- a/examples/regularized_iterative_sense_reconstruction.ipynb +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py index e41dc4ac5..2ab7ba033 100644 --- a/examples/regularized_iterative_sense_reconstruction.py +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction From dcb5e3407a32e2cbd9085ffbe4ce9430f8bacad0 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Sat, 9 Nov 2024 15:26:16 +0100 Subject: [PATCH 10/23] Add PCA-based compression operator (#181) Co-authored-by: Christoph Kolbitsch <christoph.kolbitsch@ptb.de> --- src/mrpro/operators/PCACompressionOp.py | 85 ++++++++++++++++++++++ src/mrpro/operators/__init__.py | 2 + tests/operators/test_pca_compression_op.py | 50 +++++++++++++ 3 files changed, 137 insertions(+) create mode 100644 src/mrpro/operators/PCACompressionOp.py create mode 100644 tests/operators/test_pca_compression_op.py diff --git a/src/mrpro/operators/PCACompressionOp.py b/src/mrpro/operators/PCACompressionOp.py new file mode 100644 index 000000000..ace625ea8 --- /dev/null +++ b/src/mrpro/operators/PCACompressionOp.py @@ -0,0 +1,85 @@ +"""PCA Compression Operator.""" + +import einops +import torch +from einops import repeat + +from mrpro.operators.LinearOperator import LinearOperator + + +class PCACompressionOp(LinearOperator): + """PCA based compression operator.""" + + def __init__( + self, + data: torch.Tensor, + n_components: int, + ): + """Construct a PCA based compression operator. + + The operator carries out an SVD followed by a threshold of the n_components largest values along the last + dimension of a data with shape (*other, joint_dim, compression_dim). A single SVD is carried out for everything + along joint_dim. Other are batch dimensions. + + Consider combining this operator with :class:`mrpro.operators.RearrangeOp` to make sure the data is + in the correct shape before applying. + + Parameters + ---------- + data + Data of shape (*other, joint_dim, compression_dim) to be used to find the principal components. + n_components + Number of principal components to keep along the compression_dim. + """ + super().__init__() + # different compression matrices along the *other dimensions + data = data - data.mean(-1, keepdim=True) + correlation = einops.einsum(data, data.conj(), '... joint comp1, ... joint comp2 -> ... comp1 comp2') + _, _, v = torch.svd(correlation) + # add joint_dim along which the the compression is the same + v = repeat(v, '... comp1 comp2 -> ... joint_dim comp1 comp2', joint_dim=1) + self.register_buffer('_compression_matrix', v[..., :n_components, :].clone()) + + def forward(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the compression to the data. + + Parameters + ---------- + data + data to be compressed of shape (*other, joint_dim, compression_dim) + + Returns + ------- + compressed data of shape (*other, joint_dim, n_components) + """ + try: + result = (self._compression_matrix @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix {tuple(self._compression_matrix.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) + + def adjoint(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint compression to the data. + + Parameters + ---------- + data + compressed data of shape (*other, joint_dim, n_components) + + Returns + ------- + expanded data of shape (*other, joint_dim, compression_dim) + """ + try: + result = (self._compression_matrix.mH @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix^H {tuple(self._compression_matrix.mH.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index b8c16ebfe..c22f386cd 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -14,6 +14,7 @@ from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp +from mrpro.operators.PCACompressionOp import PCACompressionOp from mrpro.operators.PhaseOp import PhaseOp from mrpro.operators.ProximableFunctionalSeparableSum import ProximableFunctionalSeparableSum from mrpro.operators.SensitivityOp import SensitivityOp @@ -41,6 +42,7 @@ "MagnitudeOp", "MultiIdentityOp", "Operator", + "PCACompressionOp", "PhaseOp", "ProximableFunctional", "ProximableFunctionalSeparableSum", diff --git a/tests/operators/test_pca_compression_op.py b/tests/operators/test_pca_compression_op.py new file mode 100644 index 000000000..a600bcecb --- /dev/null +++ b/tests/operators/test_pca_compression_op.py @@ -0,0 +1,50 @@ +"""Tests for PCA Compression Operator.""" + +import pytest +from mrpro.operators import PCACompressionOp + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +@pytest.mark.parametrize( + ('init_data_shape', 'input_shape', 'n_components'), + [ + ((40, 10), (100, 10), 6), + ((40, 10), (3, 4, 5, 100, 10), 3), + ((3, 4, 40, 10), (3, 4, 100, 10), 6), + ((3, 4, 40, 10), (7, 3, 4, 100, 10), 3), + ], +) +def test_pca_compression_op_adjoint(init_data_shape, input_shape, n_components): + """Test adjointness of PCA Compression Op.""" + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + u = generator.complex64_tensor(input_shape) + output_shape = (*input_shape[:-1], n_components) + v = generator.complex64_tensor(output_shape) + + # Create operator and apply + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=n_components) + dotproduct_adjointness_test(pca_comp_op, u, v) + + +def test_pca_compression_op_wrong_shapes(): + """Test if Operator raises error if shape mismatch.""" + init_data_shape = (10, 6) + input_shape = (100, 3) + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + input_data = generator.complex64_tensor(input_shape) + + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=2) + + with pytest.raises(RuntimeError, match='Matrix'): + pca_comp_op(input_data) + + with pytest.raises(RuntimeError, match='Matrix.H'): + pca_comp_op.adjoint(input_data) From c268ad25a429f93d3bb33df6cb2bf0ffc06c4759 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch <christoph.kolbitsch@ptb.de> Date: Sat, 9 Nov 2024 16:11:57 +0100 Subject: [PATCH 11/23] Fix CartesianSamplingOp (#483) --- src/mrpro/operators/CartesianSamplingOp.py | 10 ++++++---- tests/operators/test_cartesian_sampling_op.py | 17 +++++++++++++++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 64068a5d1..7a51924b1 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -47,26 +47,28 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2 else: sorted_grid_shape.x = ktraj_tensor.shape[-1] - kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1) + kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k2 k1 k0', other=1, k2=1, k1=1) if traj_type_kzyx[-2] == TrajType.ONGRID: # ky ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2 else: sorted_grid_shape.y = ktraj_tensor.shape[-2] - ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1) + ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k2 k1 k0', other=1, k2=1, k0=1) if traj_type_kzyx[-3] == TrajType.ONGRID: # kz kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2 else: sorted_grid_shape.z = ktraj_tensor.shape[-3] - kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1) + kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k2 k1 k0', other=1, k1=1, k0=1) # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') self.register_buffer('_fft_idx', kidx) # we can skip the indexing if the data is already sorted - self._needs_indexing = not torch.all(torch.diff(kidx) == 1) + self._needs_indexing = ( + not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + ) self._trajectory_shape = traj.broadcasted_shape self._sorted_grid_shape = sorted_grid_shape diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 6a1120e79..7caa13e7c 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -2,6 +2,7 @@ import pytest import torch +from einops import rearrange from mrpro.data import KTrajectory, SpatialDimension from mrpro.operators import CartesianSamplingOp @@ -59,6 +60,9 @@ def test_cart_sampling_op_data_match(): 'regular_undersampling', 'random_undersampling', 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', ], ) def test_cart_sampling_op_fwd_adj(sampling): @@ -70,8 +74,8 @@ def test_cart_sampling_op_fwd_adj(sampling): nky = (2, 1, 40, 1) nkz = (2, 20, 1, 1) sx = 'uf' - sy = 'uf' - sz = 'uf' + sy = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' + sz = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() # Subsample data and trajectory @@ -94,6 +98,15 @@ def test_cart_sampling_op_fwd_adj(sampling): for traj_one_other in trajectory_tensor.unbind(1) ] trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1)) + case 'cartesian_and_non_cartesian': + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0_undersampling': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + random_idx = torch.randperm(trajectory_tensor.shape[-1]) + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]]) case _: raise NotImplementedError(f'Test {sampling} not implemented.') From 54674a95b59165cbf5cb21e1a71d1bb616921554 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Mon, 11 Nov 2024 01:55:28 +0100 Subject: [PATCH 12/23] Add reduce_view (#476) Undoes expand, i.e. replaces stride 0 dimensions by size=1 dimensions --- src/mrpro/utils/__init__.py | 3 ++- src/mrpro/utils/reshape.py | 32 ++++++++++++++++++++++++++++++++ tests/utils/test_reshape.py | 32 +++++++++++++++++++++++++++++--- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index c09071f4b..b16fae37a 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -5,10 +5,11 @@ from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx -from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right +from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view __all__ = [ "broadcast_right", "modify_acq_info", + "reduce_view", "remove_repeat", "slice_profiles", "smap", diff --git a/src/mrpro/utils/reshape.py b/src/mrpro/utils/reshape.py index 39d12e51f..31d495afd 100644 --- a/src/mrpro/utils/reshape.py +++ b/src/mrpro/utils/reshape.py @@ -1,5 +1,7 @@ """Tensor reshaping utilities.""" +from collections.abc import Sequence + import torch @@ -67,3 +69,33 @@ def broadcast_right(*x: torch.Tensor) -> tuple[torch.Tensor, ...]: max_dim = max(el.ndim for el in x) unsqueezed = torch.broadcast_tensors(*(unsqueeze_right(el, max_dim - el.ndim) for el in x)) return unsqueezed + + +def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor: + """Reduce expanded dimensions in a view to singletons. + + Reduce either all or specific dimensions to a singleton if it + points to the same memory address. + This undoes expand. + + Parameters + ---------- + x + input tensor + dim + only reduce expanded dimensions in the specified dimensions. + If None, reduce all expanded dimensions. + """ + if dim is None: + dim_: Sequence[int] = range(x.ndim) + elif isinstance(dim, Sequence): + dim_ = [d % x.ndim for d in dim] + else: + dim_ = [dim % x.ndim] + + stride = x.stride() + newsize = [ + 1 if stride == 0 and d in dim_ else oldsize + for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True)) + ] + return torch.as_strided(x, newsize, stride) diff --git a/tests/utils/test_reshape.py b/tests/utils/test_reshape.py index 60a0dc5e3..dd57b8feb 100644 --- a/tests/utils/test_reshape.py +++ b/tests/utils/test_reshape.py @@ -1,7 +1,9 @@ """Tests for reshaping utilities.""" import torch -from mrpro.utils import broadcast_right, unsqueeze_left, unsqueeze_right +from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right + +from tests import RandomGenerator def test_broadcast_right(): @@ -12,7 +14,7 @@ def test_broadcast_right(): def test_unsqueeze_left(): - """Test unsqueeze left""" + """Test unsqueeze_left""" tensor = torch.ones(1, 2, 3) unsqueezed = unsqueeze_left(tensor, 2) assert unsqueezed.shape == (1, 1, 1, 2, 3) @@ -20,8 +22,32 @@ def test_unsqueeze_left(): def test_unsqueeze_right(): - """Test unsqueeze right""" + """Test unsqueeze_right""" tensor = torch.ones(1, 2, 3) unsqueezed = unsqueeze_right(tensor, 2) assert unsqueezed.shape == (1, 2, 3, 1, 1) assert torch.equal(tensor.ravel(), unsqueezed.ravel()) + + +def test_reduce_view(): + """Test reduce_view""" + + tensor = RandomGenerator(0).float32_tensor((1, 2, 3, 1, 1, 1)) + tensor = tensor.expand(1, 2, 3, 4, 1, 1).contiguous() # this cannot be removed + tensor = tensor.expand(7, 2, 3, 4, 5, 6) + + reduced_all = reduce_view(tensor) + assert reduced_all.shape == (1, 2, 3, 4, 1, 1) + assert torch.equal(reduced_all.expand_as(tensor), tensor) + + reduced_two = reduce_view(tensor, (0, -1)) + assert reduced_two.shape == (1, 2, 3, 4, 5, 1) + assert torch.equal(reduced_two.expand_as(tensor), tensor) + + reduced_one_neg = reduce_view(tensor, -1) + assert reduced_one_neg.shape == (7, 2, 3, 4, 5, 1) + assert torch.equal(reduced_one_neg.expand_as(tensor), tensor) + + reduced_one_pos = reduce_view(tensor, 0) + assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6) + assert torch.equal(reduced_one_pos.expand_as(tensor), tensor) From f0f91c3ff91296a07e6bb7a184684f6674a91795 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Mon, 11 Nov 2024 11:35:57 +0100 Subject: [PATCH 13/23] Add apply_ to dataclasses (#505) Applies a function to all children of a dataclass --- src/mrpro/data/AcqInfo.py | 14 ++---- src/mrpro/data/MoveDataMixin.py | 68 ++++++++++++++++++++++------ src/mrpro/data/SpatialDimension.py | 55 +++++----------------- tests/data/test_movedatamixin.py | 19 ++++++++ tests/data/test_spatial_dimension.py | 23 ---------- 5 files changed, 89 insertions(+), 90 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 83f752a57..a66224de1 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -1,6 +1,6 @@ """Acquisition information dataclass.""" -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass import ismrmrd @@ -206,17 +206,13 @@ def tensor_2d(data: np.ndarray) -> torch.Tensor: data_tensor = data_tensor[None, None] return data_tensor - def spatialdimension_2d( - data: np.ndarray, conversion: Callable[[torch.Tensor], torch.Tensor] | None = None - ) -> SpatialDimension[torch.Tensor]: + def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: # Ensure spatial dimension is (k1*k2*other, 1, 3) if data.ndim != 2: raise ValueError('Spatial dimension is expected to be of shape (N,3)') data = data[:, None, :] # all spatial dimensions are float32 - return ( - SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))).apply_(conversion) - ) + return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))) acq_idx = AcqIdx( k1=tensor(idx['kspace_encode_step_1']), @@ -251,10 +247,10 @@ def spatialdimension_2d( flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), - patient_table_position=spatialdimension_2d(headers['patient_table_position'], mm_to_m), + patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), phase_dir=spatialdimension_2d(headers['phase_dir']), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), - position=spatialdimension_2d(headers['position'], mm_to_m), + position=spatialdimension_2d(headers['position']).apply_(mm_to_m), read_dir=spatialdimension_2d(headers['read_dir']), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index f3f147260..8d977d0a6 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -1,13 +1,13 @@ """MoveDataMixin.""" import dataclasses -from collections.abc import Iterator +from collections.abc import Callable, Iterator from copy import copy as shallowcopy from copy import deepcopy -from typing import ClassVar, TypeAlias +from typing import ClassVar, TypeAlias, cast import torch -from typing_extensions import Any, Protocol, Self, overload, runtime_checkable +from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable class InconsistentDeviceError(ValueError): # noqa: D101 @@ -22,6 +22,9 @@ class DataclassInstance(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +T = TypeVar('T') + + class MoveDataMixin: """Move dataclass fields to cpu/gpu and convert dtypes.""" @@ -151,7 +154,6 @@ def _to( copy: bool = False, memo: dict | None = None, ) -> Self: - new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self """Move data to device and convert dtype if necessary. This method is called by .to(), .cuda(), .cpu(), .double(), and so on. @@ -179,6 +181,8 @@ def _to( memo A dictionary to keep track of already converted objects to avoid multiple conversions. """ + new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + if memo is None: memo = {} @@ -219,26 +223,62 @@ def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin: memo=memo, ) - converted: Any - for name, data in new._items(): - if id(data) in memo: - object.__setattr__(new, name, memo[id(data)]) - continue + def _convert(data: T) -> T: + converted: Any # https://github.com/python/mypy/issues/10817 if isinstance(data, torch.Tensor): converted = _tensor_to(data) elif isinstance(data, MoveDataMixin): converted = _mixin_to(data) elif isinstance(data, torch.nn.Module): converted = _module_to(data) - elif copy: - converted = deepcopy(data) else: converted = data - memo[id(data)] = converted - # this works even if new is frozen - object.__setattr__(new, name, converted) + return cast(T, converted) + + # manual recursion allows us to do the copy only once + new.apply_(_convert, memo=memo, recurse=False) return new + def apply_( + self: Self, + function: Callable[[Any], Any] | None = None, + *, + memo: dict[int, Any] | None = None, + recurse: bool = True, + ) -> Self: + """Apply a function to all children in-place. + + Parameters + ---------- + function + The function to apply to all fields. None is interpreted as a no-op. + memo + A dictionary to keep track of objects that the function has already been applied to, + to avoid multiple applications. This is useful if the object has a circular reference. + recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + applied: Any + + if memo is None: + memo = {} + + if function is None: + return self + + for name, data in self._items(): + if id(data) in memo: + # this works even if self is frozen + object.__setattr__(self, name, memo[id(data)]) + continue + if recurse and isinstance(data, MoveDataMixin): + applied = data.apply_(function, memo=memo) + else: + applied = function(data) + memo[id(data)] = applied + object.__setattr__(self, name, applied) + return self + def cuda( self, device: torch.device | str | int | None = None, diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index b5f3dfd27..46b1db89a 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -3,14 +3,13 @@ from __future__ import annotations from collections.abc import Callable -from copy import deepcopy from dataclasses import dataclass from typing import Generic, get_args import numpy as np import torch from numpy.typing import ArrayLike -from typing_extensions import Any, Protocol, TypeVar, overload +from typing_extensions import Protocol, Self, TypeVar, overload import mrpro.utils.typing as type_utils from mrpro.data.MoveDataMixin import MoveDataMixin @@ -109,6 +108,16 @@ def from_array_zyx( return SpatialDimension(z, y, x) + def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: + """Apply a function to each z, y, x (in-place). + + Parameters + ---------- + function + function to apply + """ + return super(SpatialDimension, self).apply_(function) + @property def zyx(self) -> tuple[T_co, T_co, T_co]: """Return a z,y,x tuple.""" @@ -134,48 +143,6 @@ def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexe self.y[idx] = other.y self.x[idx] = other.x - def apply_(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z in-place. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - if func is not None: - self.z = func(self.z) - self.y = func(self.y) - self.x = func(self.x) - return self - - def apply(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - - def func_(x: Any) -> T_co: # noqa: ANN401 - if isinstance(x, torch.Tensor): - # use clone for autograd - x = x.clone() - else: - x = deepcopy(x) - if func is None: - return x - else: - return func(x) - - return self.__class__(func_(self.z), func_(self.y), func_(self.x)) - - def clone(self: SpatialDimension[T_co]) -> SpatialDimension[T_co]: - """Return a deep copy of the SpatialDimension.""" - return self.apply() - @overload def __mul__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 06f55a4dc..3feb091de 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -23,6 +23,7 @@ class A(MoveDataMixin): """Test class A.""" floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0)) + floattensor2: torch.Tensor = field(default_factory=lambda: torch.tensor(-1.0)) complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) @@ -204,3 +205,21 @@ def testchild(attribute, expected_dtype): assert original is not new, 'original and new should not be the same object' assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' + + +def test_movedatamixin_apply(): + """Tests apply_ method of MoveDataMixin.""" + data = B() + # make one of the parameters shared to test memo behavior + data.child.floattensor2 = data.child.floattensor + original = data.clone() + + def multiply_by_2(obj): + if isinstance(obj, torch.Tensor): + return obj * 2 + return obj + + data.apply_(multiply_by_2) + torch.testing.assert_close(data.floattensor, original.floattensor * 2) + torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2) + assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' diff --git a/tests/data/test_spatial_dimension.py b/tests/data/test_spatial_dimension.py index cd46854b9..afafece04 100644 --- a/tests/data/test_spatial_dimension.py +++ b/tests/data/test_spatial_dimension.py @@ -115,29 +115,6 @@ def conversion(x: torch.Tensor) -> torch.Tensor: assert torch.equal(spatial_dimension_inplace.z, z) -def test_spatial_dimension_apply(): - """Test apply (out of place)""" - - def conversion(x: torch.Tensor) -> torch.Tensor: - assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' - return x.swapaxes(0, 1).square() - - xyz = RandomGenerator(0).float32_tensor((1, 2, 3)) - spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy()) - spatial_dimension_outofplace = spatial_dimension.apply().apply(conversion) - - assert spatial_dimension_outofplace is not spatial_dimension - - assert isinstance(spatial_dimension_outofplace.x, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.y, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.z, torch.Tensor) - - x, y, z = conversion(xyz).unbind(-1) - assert torch.equal(spatial_dimension_outofplace.x, x) - assert torch.equal(spatial_dimension_outofplace.y, y) - assert torch.equal(spatial_dimension_outofplace.z, z) - - def test_spatial_dimension_zyx(): """Test the zyx tuple property""" z, y, x = (2, 3, 4) From a96b9c61f2c667f439124f779e399f374a75c646 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:01:46 +0100 Subject: [PATCH 14/23] Fix installation from TestPyPi in workflow (#499) --- .github/workflows/deployment.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index dd900496e..a55b7860d 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -94,7 +94,15 @@ jobs: run: | VERSION=${{ needs.build-testpypi-package.outputs.version }} SUFFIX=${{ needs.build-testpypi-package.outputs.suffix }} - python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ + for i in {1..3}; do + if python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/; then + echo "Package installed successfully." + break + else + echo "Attempt $i failed. Retrying in 10 seconds..." + sleep 10 + fi + done build-pypi-package: name: Build Package for PyPI From 191ab06b35ecae890febb34d0d2841b8fe2cd4c5 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Mon, 11 Nov 2024 17:45:56 +0100 Subject: [PATCH 15/23] Remove check-docstring-first pre-commit hook (#508) --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95d14317a..4790095f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,6 @@ repos: rev: v5.0.0 hooks: - id: check-added-large-files - - id: check-docstring-first - id: check-merge-conflict - id: check-yaml - id: check-toml From 6c54e31ac22e80cfbafad88102b4ca1ab30d8241 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Tue, 12 Nov 2024 00:27:11 +0100 Subject: [PATCH 16/23] Revert NamedTemporaryFile ContextManager in example notebooks (#500) --- .pre-commit-config.yaml | 16 +++++++------- examples/direct_reconstruction.ipynb | 8 +++---- examples/direct_reconstruction.py | 8 +++---- examples/iterative_sense_reconstruction.ipynb | 8 +++---- examples/iterative_sense_reconstruction.py | 8 +++---- examples/pulseq_2d_radial_golden_angle.ipynb | 22 +++++++++---------- examples/pulseq_2d_radial_golden_angle.py | 17 +++++++------- ...rized_iterative_sense_reconstruction.ipynb | 8 +++---- ...ularized_iterative_sense_reconstruction.py | 8 +++---- examples/ruff.toml | 1 + 10 files changed, 51 insertions(+), 53 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4790095f9..303fd43fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,11 +53,11 @@ repos: - "--extra-index-url=https://pypi.python.org/simple" ci: - autofix_commit_msg: | - [pre-commit] auto fixes from pre-commit hooks - autofix_prs: false - autoupdate_branch: '' - autoupdate_commit_msg: '[pre-commit] pre-commit autoupdate' - autoupdate_schedule: monthly - skip: [mypy] - submodules: false + autofix_commit_msg: | + [pre-commit] auto fixes from pre-commit hooks + autofix_prs: false + autoupdate_branch: "" + autoupdate_commit_msg: "[pre-commit] pre-commit autoupdate" + autoupdate_schedule: monthly + skip: [mypy] + submodules: false diff --git a/examples/direct_reconstruction.ipynb b/examples/direct_reconstruction.ipynb index 1e4e74c9c..3b6dc930e 100644 --- a/examples/direct_reconstruction.ipynb +++ b/examples/direct_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/direct_reconstruction.py b/examples/direct_reconstruction.py index 5d55812c9..7672aa7e7 100644 --- a/examples/direct_reconstruction.py +++ b/examples/direct_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/iterative_sense_reconstruction.ipynb b/examples/iterative_sense_reconstruction.ipynb index f612d7522..87249b2fb 100644 --- a/examples/iterative_sense_reconstruction.ipynb +++ b/examples/iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/iterative_sense_reconstruction.py b/examples/iterative_sense_reconstruction.py index ba5e6a01a..6d0bc49a5 100644 --- a/examples/iterative_sense_reconstruction.py +++ b/examples/iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index bcb4482a1..52e0310bb 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -33,14 +33,13 @@ "cell_type": "code", "execution_count": null, "id": "d16f41f1", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "outputs": [], "source": [ "# define zenodo records URL and create a temporary directory and h5-file\n", "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", - "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'\n", + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')" ] }, { @@ -51,10 +50,9 @@ "outputs": [], "source": [ "# Download raw data using requests\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { @@ -127,10 +125,10 @@ "# download the sequence file from zenodo\n", "zenodo_url = 'https://zenodo.org/records/10868061/files/'\n", "seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq'\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file:\n", - " response = requests.get(zenodo_url + seq_fname, timeout=30)\n", - " seq_file.write(response.content)\n", - " seq_file.flush()" + "seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq')\n", + "response = requests.get(zenodo_url + seq_fname, timeout=30)\n", + "seq_file.write(response.content)\n", + "seq_file.flush()" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index f4db5217a..3f857c382 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -19,14 +19,13 @@ # define zenodo records URL and create a temporary directory and h5-file zenodo_url = 'https://zenodo.org/records/10854057/files/' fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' - +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') # %% # Download raw data using requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction using KTrajectoryIsmrmrd @@ -63,10 +62,10 @@ # download the sequence file from zenodo zenodo_url = 'https://zenodo.org/records/10868061/files/' seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq' -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file: - response = requests.get(zenodo_url + seq_fname, timeout=30) - seq_file.write(response.content) - seq_file.flush() +seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') +response = requests.get(zenodo_url + seq_fname, timeout=30) +seq_file.write(response.content) +seq_file.flush() # %% # Read raw data and calculate trajectory using KTrajectoryPulseq diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb index 0a6743161..6b1c2704b 100644 --- a/examples/regularized_iterative_sense_reconstruction.ipynb +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py index 2ab7ba033..e41dc4ac5 100644 --- a/examples/regularized_iterative_sense_reconstruction.py +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/ruff.toml b/examples/ruff.toml index 11a1e6167..1bb114755 100644 --- a/examples/ruff.toml +++ b/examples/ruff.toml @@ -5,4 +5,5 @@ lint.extend-ignore = [ "T20", #print "E402", #module-import-not-at-top-of-file "S101", #assert + "SIM115", #context manager for opening files ] From a89df61455675201b82e86c19b4b8b9743a5068b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Tue, 12 Nov 2024 08:58:11 +0100 Subject: [PATCH 17/23] Adapt KHeader parameters (#506) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/mrpro/data/AcqInfo.py | 41 ++++----- src/mrpro/data/KHeader.py | 48 ++++------ src/mrpro/data/TrajectoryDescription.py | 29 ------ src/mrpro/data/__init__.py | 4 +- src/mrpro/data/_kdata/KData.py | 15 ++-- src/mrpro/data/_kdata/KDataRearrangeMixin.py | 10 +-- src/mrpro/data/_kdata/KDataSelectMixin.py | 9 +- src/mrpro/data/_kdata/KDataSplitMixin.py | 28 +++--- src/mrpro/utils/__init__.py | 7 +- src/mrpro/utils/modify_acq_info.py | 35 -------- src/mrpro/utils/unit_conversion.py | 94 ++++++++++++++++++++ tests/conftest.py | 6 +- tests/data/test_kdata.py | 34 +++++-- tests/utils/test_modify_acq_info.py | 18 ---- tests/utils/test_unit_conversion.py | 82 +++++++++++++++++ 15 files changed, 278 insertions(+), 182 deletions(-) delete mode 100644 src/mrpro/data/TrajectoryDescription.py delete mode 100644 src/mrpro/utils/modify_acq_info.py create mode 100644 src/mrpro/utils/unit_conversion.py delete mode 100644 tests/utils/test_modify_acq_info.py create mode 100644 tests/utils/test_unit_conversion.py diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index a66224de1..f5d677f97 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -6,23 +6,24 @@ import ismrmrd import numpy as np import torch -from typing_extensions import Self, TypeVar +from einops import rearrange +from typing_extensions import Self from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.utils.unit_conversion import mm_to_m -# Conversion functions for units -T = TypeVar('T', float, torch.Tensor) +def rearrange_acq_info_fields(field: object, pattern: str, **axes_lengths: dict[str, int]) -> object: + """Change the shape of the fields in AcqInfo.""" + if isinstance(field, Rotation): + return Rotation.from_matrix(rearrange(field.as_matrix(), pattern, **axes_lengths)) -def ms_to_s(ms: T) -> T: - """Convert ms to s.""" - return ms / 1000 + if isinstance(field, torch.Tensor): + return rearrange(field, pattern, **axes_lengths) - -def mm_to_m(m: T) -> T: - """Convert mm to m.""" - return m / 1000 + return field @dataclass(slots=True) @@ -121,30 +122,24 @@ class AcqInfo(MoveDataMixin): number_of_samples: torch.Tensor """Number of sample points per readout (readouts may have different number of sample points).""" + orientation: Rotation + """Rotation describing the orientation of the readout, phase and slice encoding direction.""" + patient_table_position: SpatialDimension[torch.Tensor] """Offset position of the patient table, in LPS coordinates [m].""" - phase_dir: SpatialDimension[torch.Tensor] - """Directional cosine of phase encoding (2D).""" - physiology_time_stamp: torch.Tensor """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units""" position: SpatialDimension[torch.Tensor] """Center of the excited volume, in LPS coordinates relative to isocenter [m].""" - read_dir: SpatialDimension[torch.Tensor] - """Directional cosine of readout/frequency encoding.""" - sample_time_us: torch.Tensor """Readout bandwidth, as time between samples [us].""" scan_counter: torch.Tensor """Zero-indexed incrementing counter for readouts.""" - slice_dir: SpatialDimension[torch.Tensor] - """Directional cosine of slice normal, i.e. cross-product of read_dir and phase_dir.""" - trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists. """Dimensionality of the k-space trajectory vector.""" @@ -247,14 +242,16 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), + orientation=Rotation.from_directions( + spatialdimension_2d(headers['slice_dir']), + spatialdimension_2d(headers['phase_dir']), + spatialdimension_2d(headers['read_dir']), + ), patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), - phase_dir=spatialdimension_2d(headers['phase_dir']), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), position=spatialdimension_2d(headers['position']).apply_(mm_to_m), - read_dir=spatialdimension_2d(headers['read_dir']), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), - slice_dir=spatialdimension_2d(headers['slice_dir']), trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above user_float=tensor_2d(headers['user_float']), user_int=tensor_2d(headers['user_int']), diff --git a/src/mrpro/data/KHeader.py b/src/mrpro/data/KHeader.py index 488c1d9cd..dea12e28b 100644 --- a/src/mrpro/data/KHeader.py +++ b/src/mrpro/data/KHeader.py @@ -13,12 +13,12 @@ from typing_extensions import Self from mrpro.data import enums -from mrpro.data.AcqInfo import AcqInfo, mm_to_m, ms_to_s +from mrpro.data.AcqInfo import AcqInfo from mrpro.data.EncodingLimits import EncodingLimits from mrpro.data.MoveDataMixin import MoveDataMixin from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues +from mrpro.utils.unit_conversion import mm_to_m, ms_to_s if TYPE_CHECKING: # avoid circular imports by importing only when type checking @@ -40,9 +40,6 @@ class KHeader(MoveDataMixin): trajectory: KTrajectoryCalculator """Function to calculate the k-space trajectory.""" - b0: float - """Magnetic field strength [T].""" - encoding_limits: EncodingLimits """K-space encoding limits.""" @@ -61,12 +58,9 @@ class KHeader(MoveDataMixin): acq_info: AcqInfo """Information of the acquisitions (i.e. readout lines).""" - h1_freq: float + lamor_frequency_proton: float """Lamor frequency of hydrogen nuclei [Hz].""" - n_coils: int | None = None - """Number of receiver coils.""" - datetime: datetime.datetime | None = None """Date and time of acquisition.""" @@ -88,7 +82,7 @@ class KHeader(MoveDataMixin): echo_train_length: int = 1 """Number of echoes in a multi-echo acquisition.""" - seq_type: str = UNKNOWN + sequence_type: str = UNKNOWN """Type of sequence.""" model: str = UNKNOWN @@ -100,16 +94,13 @@ class KHeader(MoveDataMixin): protocol_name: str = UNKNOWN """Name of the acquisition protocol.""" - misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! - """Dictionary with miscellaneous parameters.""" - calibration_mode: enums.CalibrationMode = enums.CalibrationMode.OTHER """Mode of how calibration data is acquired. """ interleave_dim: enums.InterleavingDimension = enums.InterleavingDimension.OTHER """Interleaving dimension.""" - traj_type: enums.TrajectoryType = enums.TrajectoryType.OTHER + trajectory_type: enums.TrajectoryType = enums.TrajectoryType.OTHER """Type of trajectory.""" measurement_id: str = UNKNOWN @@ -118,8 +109,9 @@ class KHeader(MoveDataMixin): patient_name: str = UNKNOWN """Name of the patient.""" - trajectory_description: TrajectoryDescription = dataclasses.field(default_factory=TrajectoryDescription) - """Description of the trajectory.""" + _misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! + """Dictionary with miscellaneous parameters. These parameters are for information purposes only. Reconstruction + algorithms should not rely on them.""" @property def fa_degree(self) -> torch.Tensor | None: @@ -160,17 +152,14 @@ def from_ismrmrd( enc: ismrmrdschema.encodingType = header.encoding[encoding_number] # These are guaranteed to exist - parameters = {'h1_freq': header.experimentalConditions.H1resonanceFrequency_Hz, 'acq_info': acq_info} + parameters = { + 'lamor_frequency_proton': header.experimentalConditions.H1resonanceFrequency_Hz, + 'acq_info': acq_info, + } if defaults is not None: parameters.update(defaults) - if ( - header.acquisitionSystemInformation is not None - and header.acquisitionSystemInformation.receiverChannels is not None - ): - parameters['n_coils'] = header.acquisitionSystemInformation.receiverChannels - if header.sequenceParameters is not None: if header.sequenceParameters.TR: parameters['tr'] = ms_to_s(torch.as_tensor(header.sequenceParameters.TR)) @@ -184,7 +173,7 @@ def from_ismrmrd( parameters['echo_spacing'] = ms_to_s(torch.as_tensor(header.sequenceParameters.echo_spacing)) if header.sequenceParameters.sequence_type is not None: - parameters['seq_type'] = header.sequenceParameters.sequence_type + parameters['sequence_type'] = header.sequenceParameters.sequence_type if enc.reconSpace is not None: parameters['recon_fov'] = SpatialDimension[float].from_xyz(enc.reconSpace.fieldOfView_mm).apply_(mm_to_m) @@ -212,7 +201,7 @@ def from_ismrmrd( ) if enc.trajectory is not None: - parameters['traj_type'] = enums.TrajectoryType(enc.trajectory.value) + parameters['trajectory_type'] = enums.TrajectoryType(enc.trajectory.value) # Either use the series or study time if available if header.measurementInformation is not None and header.measurementInformation.seriesTime is not None: @@ -245,15 +234,8 @@ def from_ismrmrd( if header.acquisitionSystemInformation.systemModel is not None: parameters['model'] = header.acquisitionSystemInformation.systemModel - if header.acquisitionSystemInformation.systemFieldStrength_T is not None: - parameters['b0'] = header.acquisitionSystemInformation.systemFieldStrength_T - - # estimate b0 from h1_freq if not given - if 'b0' not in parameters: - parameters['b0'] = parameters['h1_freq'] / 4258e4 - # Dump everything into misc - parameters['misc'] = dataclasses.asdict(header) + parameters['_misc'] = dataclasses.asdict(header) if overwrite is not None: parameters.update(overwrite) diff --git a/src/mrpro/data/TrajectoryDescription.py b/src/mrpro/data/TrajectoryDescription.py deleted file mode 100644 index 801811005..000000000 --- a/src/mrpro/data/TrajectoryDescription.py +++ /dev/null @@ -1,29 +0,0 @@ -"""TrajectoryDescription dataclass.""" - -import dataclasses -from dataclasses import dataclass - -from ismrmrd.xsd.ismrmrdschema.ismrmrd import trajectoryDescriptionType -from typing_extensions import Self - - -@dataclass(slots=True) -class TrajectoryDescription: - """TrajectoryDescription dataclass.""" - - identifier: str = '' - user_parameter_long: dict[str, int] = dataclasses.field(default_factory=dict) - user_parameter_double: dict[str, float] = dataclasses.field(default_factory=dict) - user_parameter_string: dict[str, str] = dataclasses.field(default_factory=dict) - comment: str = '' - - @classmethod - def from_ismrmrd(cls, trajectory_description: trajectoryDescriptionType) -> Self: - """Create TrajectoryDescription from ismrmrd traj description.""" - return cls( - user_parameter_long={p.name: int(p.value) for p in trajectory_description.userParameterLong}, - user_parameter_double={p.name: float(p.value) for p in trajectory_description.userParameterDouble}, - user_parameter_string={p.name: str(p.value) for p in trajectory_description.userParameterString}, - comment=trajectory_description.comment or '', - identifier=trajectory_description.identifier or '', - ) diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index b5034d668..d5667a5bc 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -16,7 +16,6 @@ from mrpro.data.QHeader import QHeader from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription __all__ = [ "AcqIdx", "AcqInfo", @@ -37,8 +36,7 @@ "QHeader", "Rotation", "SpatialDimension", - "TrajectoryDescription", "acq_filters", "enums", "traj_calculators" -] \ No newline at end of file +] diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index 409e8aac9..57af617bc 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -18,15 +18,15 @@ from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin from mrpro.data.acq_filters import is_image_acquisition -from mrpro.data.AcqInfo import AcqInfo +from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd -from mrpro.utils import modify_acq_info KDIM_SORT_LABELS = ( 'k1', @@ -200,10 +200,13 @@ def from_file( sort_idx = np.lexsort(acq_indices) # torch does not have lexsort as of pytorch 2.2 (March 2024) # Finally, reshape and sort the tensors in acqinfo and acqinfo.idx, and kdata. - def sort_and_reshape_tensor_fields(input_tensor: torch.Tensor): - return rearrange(input_tensor[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2) - - kheader.acq_info = modify_acq_info(sort_and_reshape_tensor_fields, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2 + ) + if isinstance(field, torch.Tensor | Rotation) + else field + ) kdata = rearrange(kdata[sort_idx], '(other k2 k1) coils k0 -> other coils k2 k1 k0', k1=n_k1, k2=n_k2) # Calculate trajectory and check if it matches the kdata shape diff --git a/src/mrpro/data/_kdata/KDataRearrangeMixin.py b/src/mrpro/data/_kdata/KDataRearrangeMixin.py index 05bab7681..23a58dea6 100644 --- a/src/mrpro/data/_kdata/KDataRearrangeMixin.py +++ b/src/mrpro/data/_kdata/KDataRearrangeMixin.py @@ -6,8 +6,7 @@ from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import AcqInfo -from mrpro.utils import modify_acq_info +from mrpro.data.AcqInfo import rearrange_acq_info_fields class KDataRearrangeMixin(_KDataProtocol): @@ -35,9 +34,8 @@ def rearrange_k2_k1_into_k1(self: Self) -> Self: kheader = copy.deepcopy(self.header) # Update shape of acquisition info index - def reshape_acq_info(info: AcqInfo): - return rearrange(info, 'other k2 k1 ... -> other 1 (k2 k1) ...') - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...') + ) return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataSelectMixin.py b/src/mrpro/data/_kdata/KDataSelectMixin.py index fb0e02aa1..8f8a452cf 100644 --- a/src/mrpro/data/_kdata/KDataSelectMixin.py +++ b/src/mrpro/data/_kdata/KDataSelectMixin.py @@ -7,7 +7,7 @@ from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation class KDataSelectMixin(_KDataProtocol): @@ -51,10 +51,9 @@ def select_other_subset( other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) # Adapt header - def select_acq_info(info: torch.Tensor): - return info[other_idx, ...] - - kheader.acq_info = modify_acq_info(select_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field + ) # Select data kdat = self.data[other_idx, ...] diff --git a/src/mrpro/data/_kdata/KDataSplitMixin.py b/src/mrpro/data/_kdata/KDataSplitMixin.py index d2c641125..c28004af4 100644 --- a/src/mrpro/data/_kdata/KDataSplitMixin.py +++ b/src/mrpro/data/_kdata/KDataSplitMixin.py @@ -1,15 +1,17 @@ """Mixin class to split KData into other subsets.""" -import copy -from typing import Literal +from typing import Literal, TypeVar, cast import torch from einops import rearrange, repeat from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation + +RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) class KDataSplitMixin(_KDataProtocol): @@ -56,8 +58,9 @@ def _split_k2_or_k1_into_other( def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, :, split_idx, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, :, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + # cast due to https://github.com/python/mypy/issues/10817 + return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' @@ -69,8 +72,8 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, split_idx, :, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + return cast(RotationOrTensor, acq_info[:, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' @@ -93,13 +96,14 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) # Create new header with correct shape - kheader = copy.deepcopy(self.header) + kheader = self.header.clone() # Update shape of acquisition info index - def reshape_acq_info(info: torch.Tensor): - return rearrange(split_acq_info(info), rearrange_pattern_acq_info) - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) + if isinstance(field, Rotation | torch.Tensor) + else field + ) # Update other label limits and acquisition info setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index b16fae37a..6cd18c2cc 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,21 +1,22 @@ import mrpro.utils.slice_profiles import mrpro.utils.typing +import mrpro.utils.unit_conversion from mrpro.utils.smap import smap from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop -from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view + __all__ = [ "broadcast_right", - "modify_acq_info", "reduce_view", "remove_repeat", "slice_profiles", "smap", "split_idx", "typing", + "unit_conversion", "unsqueeze_left", "unsqueeze_right", "zero_pad_or_crop" -] \ No newline at end of file +] diff --git a/src/mrpro/utils/modify_acq_info.py b/src/mrpro/utils/modify_acq_info.py deleted file mode 100644 index d535e53c5..000000000 --- a/src/mrpro/utils/modify_acq_info.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Modify AcqInfo.""" - -from __future__ import annotations - -import dataclasses -from collections.abc import Callable -from typing import TYPE_CHECKING - -import torch - -if TYPE_CHECKING: - from mrpro.data.AcqInfo import AcqInfo - - -def modify_acq_info(fun_modify: Callable, acq_info: AcqInfo) -> AcqInfo: - """Go through all fields of AcqInfo object and apply changes. - - Parameters - ---------- - fun_modify - Function which takes AcqInfo fields as input and returns modified AcqInfo field - acq_info - AcqInfo object - """ - # Apply function to all fields of acq_info - for field in dataclasses.fields(acq_info): - current = getattr(acq_info, field.name) - if isinstance(current, torch.Tensor): - setattr(acq_info, field.name, fun_modify(current)) - elif dataclasses.is_dataclass(current): - for subfield in dataclasses.fields(current): - subcurrent = getattr(current, subfield.name) - setattr(current, subfield.name, fun_modify(subcurrent)) - - return acq_info diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py new file mode 100644 index 000000000..0115bed47 --- /dev/null +++ b/src/mrpro/utils/unit_conversion.py @@ -0,0 +1,94 @@ +"""Conversion between different units.""" + +from typing import TypeVar + +import numpy as np +import torch + +__all__ = [ + 'ms_to_s', + 's_to_ms', + 'mm_to_m', + 'm_to_mm', + 'deg_to_rad', + 'rad_to_deg', + 'lamor_frequency_to_magnetic_field', + 'magnetic_field_to_lamor_frequency', + 'GYROMAGNETIC_RATIO_PROTON', +] + +GYROMAGNETIC_RATIO_PROTON = 42.58 * 1e6 +r"""The gyromagnetic ratio :math:`\frac{\gamma}{2\pi}` of 1H in H20 in Hz/T""" + +# Conversion functions for units +T = TypeVar('T', float, torch.Tensor) + + +def ms_to_s(ms: T) -> T: + """Convert ms to s.""" + return ms / 1000 + + +def s_to_ms(s: T) -> T: + """Convert s to ms.""" + return s * 1000 + + +def mm_to_m(mm: T) -> T: + """Convert mm to m.""" + return mm / 1000 + + +def m_to_mm(m: T) -> T: + """Convert m to mm.""" + return m * 1000 + + +def deg_to_rad(deg: T) -> T: + """Convert degree to radians.""" + if isinstance(deg, torch.Tensor): + return torch.deg2rad(deg) + return deg / 180.0 * np.pi + + +def rad_to_deg(deg: T) -> T: + """Convert radians to degree.""" + if isinstance(deg, torch.Tensor): + return torch.rad2deg(deg) + return deg * 180.0 / np.pi + + +def lamor_frequency_to_magnetic_field(lamor_frequency: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON) -> T: + """Convert the Lamor frequency [Hz] to the magntic field strength [T]. + + Parameters + ---------- + lamor_frequency + Lamor frequency [Hz] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Magnetic field strength [T] + """ + return lamor_frequency / gyromagnetic_ratio + + +def magnetic_field_to_lamor_frequency( + magnetic_field_strength: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON +) -> T: + """Convert the magntic field strength [T] to Lamor frequency [Hz]. + + Parameters + ---------- + magnetic_field_strength + Strength of the magnetic field [T] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Lamor frequency [Hz] + """ + return magnetic_field_strength * gyromagnetic_ratio diff --git a/tests/conftest.py b/tests/conftest.py index e3f943462..899e8959c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,9 +45,9 @@ def generate_random_acquisition_properties(generator: RandomGenerator): 'encoding_space_ref': generator.uint16(), 'sample_time_us': generator.float32(), 'position': generator.float32_tuple(3), - 'read_dir': generator.float32_tuple(3), - 'phase_dir': generator.float32_tuple(3), - 'slice_dir': generator.float32_tuple(3), + 'read_dir': (1, 0, 0), # read, phase and slice have to form rotation + 'phase_dir': (0, 1, 0), + 'slice_dir': (0, 0, 1), 'patient_table_position': generator.float32_tuple(3), 'idx': ismrmrd.EncodingCounters(**idx_properties), 'user_int': generator.uint32_tuple(8), diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index 822c63045..b7ec1fa7d 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -2,12 +2,13 @@ import pytest import torch -from einops import rearrange, repeat +from einops import repeat from mrpro.data import KData, KTrajectory, SpatialDimension from mrpro.data.acq_filters import is_coil_calibration_acquisition +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp -from mrpro.utils import modify_acq_info, split_idx +from mrpro.utils import split_idx from tests.conftest import RandomGenerator, generate_random_data from tests.data import IsmrmrdRawTestData @@ -77,10 +78,11 @@ def consistently_shaped_kdata(request, random_kheader_shape): # Start with header kheader, n_other, n_coils, n_k2, n_k1, n_k0 = random_kheader_shape - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) # Create kdata with consistent shape kdata = generate_random_data(RandomGenerator(request.param['seed']), (n_other, n_coils, n_k2, n_k1, n_k0)) @@ -162,7 +164,7 @@ def test_KData_kspace(ismrmrd_cart): assert relative_image_difference(reconstructed_img[0, 0, 0, ...], ismrmrd_cart.img_ref) <= 0.05 -@pytest.mark.parametrize(('field', 'value'), [('b0', 11.3), ('tr', torch.tensor([24.3]))]) +@pytest.mark.parametrize(('field', 'value'), [('lamor_frequency_proton', 42.88 * 1e6), ('tr', torch.tensor([24.3]))]) def test_KData_modify_header(ismrmrd_cart, field, value): """Overwrite some parameters in the header.""" parameter_dict = {field: value} @@ -469,3 +471,21 @@ def test_KData_remove_readout_os(monkeypatch, random_kheader): # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high # differences along the edges of the elliptic objects. assert relative_image_difference(torch.abs(img_recon), img_tensor[:, 0, ...]) <= 0.05 + + +def test_modify_acq_info(random_kheader_shape): + """Test the modification of the acquisition info.""" + # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] + kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape + + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) + + # Verify shape + assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.idx.k1.shape == (n_other, n_k2, n_k1) + assert kheader.acq_info.orientation.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.position.z.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/utils/test_modify_acq_info.py b/tests/utils/test_modify_acq_info.py deleted file mode 100644 index 451303d02..000000000 --- a/tests/utils/test_modify_acq_info.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for modification of acquisition infos.""" - -from einops import rearrange -from mrpro.utils import modify_acq_info - - -def test_modify_acq_info(random_kheader_shape): - """Test the modification of the acquisition info.""" - # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] - kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape - - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) - - # Verify shape - assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/utils/test_unit_conversion.py b/tests/utils/test_unit_conversion.py new file mode 100644 index 000000000..a232a5366 --- /dev/null +++ b/tests/utils/test_unit_conversion.py @@ -0,0 +1,82 @@ +"""Tests of unit conversion.""" + +import numpy as np +import torch +from mrpro.utils.unit_conversion import ( + deg_to_rad, + lamor_frequency_to_magnetic_field, + m_to_mm, + magnetic_field_to_lamor_frequency, + mm_to_m, + ms_to_s, + rad_to_deg, + s_to_ms, +) + +from tests import RandomGenerator + + +def test_mm_to_m(): + """Verify mm to m conversion.""" + generator = RandomGenerator(seed=0) + mm_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(mm_to_m(mm_input), mm_input / 1000.0) + + +def test_m_to_mm(): + """Verify m to mm conversion.""" + generator = RandomGenerator(seed=0) + m_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(m_to_mm(m_input), m_input * 1000.0) + + +def test_ms_to_s(): + """Verify ms to s conversion.""" + generator = RandomGenerator(seed=0) + ms_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(ms_to_s(ms_input), ms_input / 1000.0) + + +def test_s_to_ms(): + """Verify s to ms conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(s_to_ms(s_input), s_input * 1000.0) + + +def test_rad_to_deg_tensor(): + """Verify radians to degree conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(rad_to_deg(s_input), torch.rad2deg(s_input)) + + +def test_deg_to_rad_tensor(): + """Verify degree to radians conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(deg_to_rad(s_input), torch.deg2rad(s_input)) + + +def test_rad_to_deg_float(): + """Verify radians to degree conversion.""" + assert rad_to_deg(np.pi / 2) == 90.0 + + +def test_deg_to_rad_float(): + """Verify degree to radians conversion.""" + assert deg_to_rad(180.0) == np.pi + + +def test_lamor_frequency_to_magnetic_field(): + """Verify conversion of lamor frequency to magnetic field.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + proton_lamor_frequency_at_3tesla = 127.74 * 1e6 + assert lamor_frequency_to_magnetic_field(proton_lamor_frequency_at_3tesla, proton_gyromagnetic_ratio) == 3.0 + + +def test_magnetic_field_to_lamor_frequency(): + """Verify conversion of magnetic field to lamor frequency.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + magnetic_field_strength = 3.0 + assert magnetic_field_to_lamor_frequency(magnetic_field_strength, proton_gyromagnetic_ratio) == 127.74 * 1e6 From 84b983cda06f3408c3b255866eb5c1087f7f2e48 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Tue, 12 Nov 2024 10:51:20 +0100 Subject: [PATCH 18/23] Release v0.241112 (#510) --- src/mrpro/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 0f6ae6fb6..c60039027 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241029 +0.241112 From 762fcd777adca6427751a2d00015748d653416b4 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch <christoph.kolbitsch@ptb.de> Date: Tue, 12 Nov 2024 12:29:54 +0100 Subject: [PATCH 19/23] Select k-space data based on n_coils (#309) Co-authored-by: Felix F Zimmermann <fzimmermann89@gmail.com> --- src/mrpro/data/_kdata/KData.py | 25 +++++++++++++++++- src/mrpro/data/acq_filters.py | 17 ++++++++++++ src/mrpro/utils/__init__.py | 1 + tests/data/_IsmrmrdRawTestData.py | 28 +++++++++++++------- tests/data/test_kdata.py | 44 ++++++++++++++++++++++++++++++- 5 files changed, 104 insertions(+), 11 deletions(-) diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index 57af617bc..d43fc49cd 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -17,7 +17,7 @@ from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin -from mrpro.data.acq_filters import is_image_acquisition +from mrpro.data.acq_filters import has_n_coils, is_image_acquisition from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits from mrpro.data.KHeader import KHeader @@ -110,6 +110,29 @@ def from_file( modification_time = datetime.datetime.fromtimestamp(mtime) acquisitions = [acq for acq in acquisitions if acquisition_filter_criterion(acq)] + + # we need the same number of receiver coils for all acquisitions + n_coils_available = {acq.data.shape[0] for acq in acquisitions} + if len(n_coils_available) > 1: + if ( + ismrmrd_header.acquisitionSystemInformation is not None + and ismrmrd_header.acquisitionSystemInformation.receiverChannels is not None + ): + n_coils = int(ismrmrd_header.acquisitionSystemInformation.receiverChannels) + else: + # most likely, highest number of elements are the coils used for imaging + n_coils = int(max(n_coils_available)) + + warnings.warn( + f'Acquisitions with different number {n_coils_available} of receiver coil elements detected.' + 'Data with {n_coils} receiver coil elements will be used.', + stacklevel=1, + ) + acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)] + + if not acquisitions: + raise ValueError('No acquisitions meeting the given filter criteria were found.') + kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions]) acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions) diff --git a/src/mrpro/data/acq_filters.py b/src/mrpro/data/acq_filters.py index d64c4d9a6..4723d3bba 100644 --- a/src/mrpro/data/acq_filters.py +++ b/src/mrpro/data/acq_filters.py @@ -61,3 +61,20 @@ def is_coil_calibration_acquisition(acquisition: ismrmrd.Acquisition) -> bool: """ coil_calibration_flag = AcqFlags.ACQ_IS_PARALLEL_CALIBRATION | AcqFlags.ACQ_IS_PARALLEL_CALIBRATION_AND_IMAGING return coil_calibration_flag.value & acquisition.flags + + +def has_n_coils(n_coils: int, acquisition: ismrmrd.Acquisition) -> bool: + """Test if acquisitions was obtained with a certain number of receiver coils. + + Parameters + ---------- + n_coils + number of receiver coils + acquisition + ISMRMRD acquisition + + Returns + ------- + True if the acquisition was obtained with n_coils receiver coils + """ + return acquisition.data.shape[0] == n_coils diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 6cd18c2cc..80ef9d398 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -6,6 +6,7 @@ from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view +import mrpro.utils.unit_conversion __all__ = [ "broadcast_right", diff --git a/tests/data/_IsmrmrdRawTestData.py b/tests/data/_IsmrmrdRawTestData.py index 59c42be15..181be56d6 100644 --- a/tests/data/_IsmrmrdRawTestData.py +++ b/tests/data/_IsmrmrdRawTestData.py @@ -67,6 +67,7 @@ def __init__( trajectory_type: Literal['cartesian', 'radial'] = 'cartesian', sampling_order: Literal['linear', 'low_high', 'high_low', 'random'] = 'linear', phantom: EllipsePhantom | None = None, + add_bodycoil_acquisitions: bool = False, n_separate_calibration_lines: int = 0, ): if not phantom: @@ -222,23 +223,32 @@ def __init__( acq.phase_dir[1] = 1.0 acq.slice_dir[2] = 1.0 - # Initialize an acquisition counter - counter = 0 + scan_counter = 0 # Write out a few noise scans for _ in range(32): noise = self.noise_level * torch.randn(self.n_coils, n_freq_encoding, dtype=torch.complex64) # here's where we would make the noise correlated - acq.scan_counter = counter + acq.scan_counter = scan_counter acq.clearAllFlags() acq.setFlag(ismrmrd.ACQ_IS_NOISE_MEASUREMENT) acq.data[:] = noise.numpy() dataset.append_acquisition(acq) - counter += 1 # increment the scan counter + scan_counter += 1 + + # Add acquisitions obtained with a 2-element body coil (e.g. used for adjustment scans) + if add_bodycoil_acquisitions: + acq.resize(n_freq_encoding, 2, trajectory_dimensions=2) + for _ in range(8): + acq.scan_counter = scan_counter + acq.clearAllFlags() + acq.data[:] = torch.randn(2, n_freq_encoding, dtype=torch.complex64) + dataset.append_acquisition(acq) + scan_counter += 1 + acq.resize(n_freq_encoding, self.n_coils, trajectory_dimensions=2) # Calibration lines if n_separate_calibration_lines > 0: - # we take calibration lines around the k-space center traj_ky_calibration, traj_kx_calibration, kpe_calibration = self._cartesian_trajectory( n_separate_calibration_lines, n_freq_encoding, @@ -253,7 +263,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe_calibration): # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -264,7 +274,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_calibration[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Loop over the repetitions, add noise and write to disk for rep in range(self.repetitions): @@ -275,7 +285,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe[rep]): if not self.flag_invalid_reps or rep == 0 or pe_idx < len(kpe[rep]) // 2: # fewer lines for rep > 0 # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -298,7 +308,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_with_noise[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Clean up dataset.close() diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index b7ec1fa7d..ab4f5aabb 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -4,7 +4,7 @@ import torch from einops import repeat from mrpro.data import KData, KTrajectory, SpatialDimension -from mrpro.data.acq_filters import is_coil_calibration_acquisition +from mrpro.data.acq_filters import has_n_coils, is_coil_calibration_acquisition, is_image_acquisition from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp @@ -29,6 +29,20 @@ def ismrmrd_cart(ellipse_phantom, tmp_path_factory): return ismrmrd_kdata +@pytest.fixture(scope='session') +def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set with bodycoil and surface coil data.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + add_bodycoil_acquisitions=True, + ) + return ismrmrd_kdata + + @pytest.fixture(scope='session') def ismrmrd_cart_with_calibration_lines(ellipse_phantom, tmp_path_factory): """Undersampled Cartesian data set with calibration lines.""" @@ -126,6 +140,34 @@ def test_KData_raise_wrong_trajectory_shape(ismrmrd_cart): _ = KData.from_file(ismrmrd_cart.filename, trajectory) +def test_KData_raise_warning_for_bodycoil(ismrmrd_cart_bodycoil_and_surface_coil): + """Mix of bodycoil and surface coil acquisitions leads to warning.""" + with pytest.raises(UserWarning, match='Acquisitions with different number'): + _ = KData.from_file(ismrmrd_cart_bodycoil_and_surface_coil.filename, DummyTrajectory()) + + +@pytest.mark.filterwarnings('ignore:Acquisitions with different number:UserWarning') +def test_KData_select_bodycoil_via_filter(ismrmrd_cart_bodycoil_and_surface_coil): + """Bodycoil can be selected via a custom acquisition filter.""" + # This is the recommended way of selecting the body coil (i.e. 2 receiver elements) + kdata = KData.from_file( + ismrmrd_cart_bodycoil_and_surface_coil.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + assert kdata.data.shape[-4] == 2 + + +def test_KData_raise_wrong_coil_number(ismrmrd_cart): + """Wrong number of coils leads to empty acquisitions.""" + with pytest.raises(ValueError, match='No acquisitions meeting the given filter criteria were found'): + _ = KData.from_file( + ismrmrd_cart.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + + def test_KData_from_file_diff_nky_for_rep(ismrmrd_cart_invalid_reps): """Multiple repetitions with different number of phase encoding lines.""" with pytest.warns(UserWarning, match=r'different number'): From 455679547c70e55aff86c40c0b1ee1ab98742e90 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch <christoph.kolbitsch@ptb.de> Date: Tue, 12 Nov 2024 13:01:04 +0100 Subject: [PATCH 20/23] Fix formatting in warning string (#514) --- src/mrpro/data/_kdata/KData.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index d43fc49cd..aaf430497 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -124,8 +124,8 @@ def from_file( n_coils = int(max(n_coils_available)) warnings.warn( - f'Acquisitions with different number {n_coils_available} of receiver coil elements detected.' - 'Data with {n_coils} receiver coil elements will be used.', + f'Acquisitions with different number {n_coils_available} of receiver coil elements detected. ' + f'Data with {n_coils} receiver coil elements will be used.', stacklevel=1, ) acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)] From a8214130948447f27e1a1c540b366f865acfae5d Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann <fzimmermann89@gmail.com> Date: Tue, 12 Nov 2024 13:05:32 +0100 Subject: [PATCH 21/23] Add human readable ids to test (#485) --- tests/algorithms/test_cg.py | 3 +- tests/conftest.py | 300 +++++++++--------- tests/data/test_rotation.py | 8 +- tests/data/test_trajectory.py | 20 +- tests/operators/functionals/conftest.py | 6 +- tests/operators/models/conftest.py | 20 +- tests/operators/test_cartesian_sampling_op.py | 16 +- tests/operators/test_fourier_op.py | 36 ++- tests/operators/test_rearrangeop.py | 1 + 9 files changed, 217 insertions(+), 193 deletions(-) diff --git a/tests/algorithms/test_cg.py b/tests/algorithms/test_cg.py index 8a4434e2a..4abda2a98 100644 --- a/tests/algorithms/test_cg.py +++ b/tests/algorithms/test_cg.py @@ -16,7 +16,8 @@ (1, 32, False), (4, 32, True), (4, 32, False), - ] + ], + ids=['complex_single', 'real_single', 'complex_batch', 'real_batch'], ) def system(request): """Generate data for creating a system Hx=b with linear and self-adjoint diff --git a/tests/conftest.py b/tests/conftest.py index 899e8959c..1fb7fb95f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -233,16 +233,16 @@ def create_uniform_traj(nk, k_shape): return k -def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): +def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Create trajectory with random entries.""" random_generator = RandomGenerator(seed=0) k_list = [] - for spacing, nk in zip([sz, sy, sx], [nkz, nky, nkx], strict=True): - if spacing == 'nuf': + for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True): + if spacing == 'non-uniform': k = random_generator.float32_tensor(size=nk) - elif spacing == 'uf': + elif spacing == 'uniform': k = create_uniform_traj(nk, k_shape=k_shape) - elif spacing == 'z': + elif spacing == 'zero': k = torch.zeros(nk) k_list.append(k) trajectory = KTrajectory(k_list[0], k_list[1], k_list[2], repeat_detection_tolerance=None) @@ -250,161 +250,163 @@ def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz', 's0', 's1', 's2'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ - # (0) 2d cart mri with 1 coil, no oversampling - ( - (1, 1, 1, 96, 128), # img shape - (1, 1, 1, 96, 128), # k shape - (1, 1, 1, 128), # kx - (1, 1, 96, 1), # ky - (1, 1, 1, 1), # kz - 'uf', # kx is uniform - 'uf', # ky is uniform - 'z', # zero so no Fourier transform is performed along that dimension - 'uf', # k0 is uniform - 'uf', # k1 is uniform - 'z', # k2 is singleton + ( # (0) 2d Cartesian single coil, no oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 96, 128), # k_shape + (1, 1, 1, 128), # nkx + (1, 1, 96, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (1) 2d cart mri with 1 coil, with oversampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 128, 192), - (1, 1, 1, 192), - (1, 1, 128, 1), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (1) 2d Cartesian single coil, with oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 128, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 128, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (2) 2d non-Cartesian mri with 2 coils - ( - (1, 2, 1, 96, 128), - (1, 2, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 1, 1), - 'nuf', # kx is non-uniform - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (2) 2d non-Cartesian mri with 2 coils + (1, 2, 1, 96, 128), # im_shape + (1, 2, 1, 16, 192), # k_shape + (1, 1, 16, 192), # nkx + (1, 1, 16, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (3) 2d cart mri with irregular sampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'z', - 'z', + ( # (3) 2d Cartesian with irregular sampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (4) 2d single shot spiral - ( - (1, 2, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'z', - 'z', + ( # (4) 2d single shot spiral + (1, 2, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (5) 3d nuFFT mri, 4 coils, 2 other - ( - (2, 4, 16, 32, 64), - (2, 4, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (5) 3d non-uniform, 4 coils, 2 other + (2, 4, 16, 32, 64), # im_shape + (2, 4, 16, 32, 64), # k_shape + (2, 16, 32, 64), # nkx + (2, 16, 32, 64), # nky + (2, 16, 32, 64), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (6) 2d nuFFT cine mri with 8 cardiac phases, 5 coils - ( - (8, 5, 1, 64, 64), - (8, 5, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (6) 2d non-uniform cine with 8 cardiac phases, 5 coils + (8, 5, 1, 64, 64), # im_shape + (8, 5, 1, 18, 128), # k_shape + (8, 1, 18, 128), # nkx + (8, 1, 18, 128), # nky + (8, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (7) 2d cart cine mri with 9 cardiac phases, 6 coils - ( - (9, 6, 1, 96, 128), - (9, 6, 1, 128, 192), - (9, 1, 1, 192), - (9, 1, 128, 1), - (9, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (7) 2d cartesian cine with 9 cardiac phases, 6 coils + (9, 6, 1, 96, 128), # im_shape + (9, 6, 1, 128, 192), # k_shape + (9, 1, 1, 192), # nkx + (9, 1, 128, 1), # nky + (9, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and nuFFT directions - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'uf', - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', + ( # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and non-uniform directions + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (9) radial phase encoding (RPE) , 8 coils with non-Cartesian sampling along readout - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (9) radial phase encoding (RPE), 8 coils with non-Cartesian sampling along readout + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and nuFFT directions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', - 'uf', + ( # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and non-uniform directions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 1, 18, 64), # nky + (5, 96, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'uniform', # type_k2 ), ], + ids=[ + '2d_cartesian_1_coil_no_oversampling', + '2d_cartesian_1_coil_with_oversampling', + '2d_non_cartesian_mri_2_coils', + '2d_cartesian_irregular_sampling', + '2d_single_shot_spiral', + '3d_nonuniform_4_coils_2_other', + '2d_nnonuniform_cine_mri_8_cardiac_phases_5_coils', + '2d_cartesian_cine_9_cardiac_phases_6_coils', + 'radial_phase_encoding_8_coils_with_oversampling', + 'radial_phase_encoding_8_coils_non_cartesian_sampling', + 'stack_of_stars_5_other_3_coil_with_oversampling', + ], ) diff --git a/tests/data/test_rotation.py b/tests/data/test_rotation.py index 6b4cbb52c..035a2d6c6 100644 --- a/tests/data/test_rotation.py +++ b/tests/data/test_rotation.py @@ -535,7 +535,7 @@ def _test_stats(error: torch.Tensor, mean_max: float, rms_max: float) -> None: assert torch.all(rms < rms_max) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -555,7 +555,7 @@ def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-15, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_symmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -576,7 +576,7 @@ def test_as_euler_symmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-16, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix @@ -598,7 +598,7 @@ def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): torch.testing.assert_close(mat_expected, mat_estimated) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_symmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index 1baf4340b..1061a93be 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -147,16 +147,16 @@ def test_trajectory_cpu(cartesian_grid): @COMMON_MR_TRAJECTORIES -def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the kz, ky and kz dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_kzyx[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) @@ -171,16 +171,16 @@ def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, @COMMON_MR_TRAJECTORIES -def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the k2, k1 and k0 dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_k210[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) diff --git a/tests/operators/functionals/conftest.py b/tests/operators/functionals/conftest.py index 9d2bfd8c9..ae8ffc20b 100644 --- a/tests/operators/functionals/conftest.py +++ b/tests/operators/functionals/conftest.py @@ -47,12 +47,12 @@ def result_dtype(self): def functional_test_cases(func: Callable[[FunctionalTestCase], None]) -> Callable[..., None]: """Decorator combining multiple parameterizations for test cases for all proximable functionals.""" - @pytest.mark.parametrize('shape', [[1, 2, 3]]) + @pytest.mark.parametrize('shape', [[1, 2, 3]], ids=['shape=[1,2,3]']) @pytest.mark.parametrize('dtype_name', ['float32', 'complex64']) @pytest.mark.parametrize('weight', ['scalar_weight', 'tensor_weight', 'complex_weight']) @pytest.mark.parametrize('target', ['no_target', 'random_target']) - @pytest.mark.parametrize('dim', [None]) - @pytest.mark.parametrize('divide_by_n', [True, False]) + @pytest.mark.parametrize('dim', [None], ids=['dim=None']) + @pytest.mark.parametrize('divide_by_n', [True, False], ids=['mean', 'sum']) @pytest.mark.parametrize('functional', PROXIMABLE_FUNCTIONALS) def wrapper( functional: type[ElementaryProximableFunctional], diff --git a/tests/operators/models/conftest.py b/tests/operators/models/conftest.py index 570fa2f1f..75fceacd2 100644 --- a/tests/operators/models/conftest.py +++ b/tests/operators/models/conftest.py @@ -8,7 +8,7 @@ SHAPE_VARIATIONS_SIGNAL_MODELS = pytest.mark.parametrize( ('parameter_shape', 'contrast_dim_shape', 'signal_shape'), [ - ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different inversion times + ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different contrast times ((1, 1, 10, 20, 30), (5, 1), (5, 1, 1, 10, 20, 30)), ((4, 1, 1, 10, 20, 30), (5, 1), (5, 4, 1, 1, 10, 20, 30)), # multiple maps along additional batch dimension ((4, 1, 1, 10, 20, 30), (5,), (5, 4, 1, 1, 10, 20, 30)), @@ -25,6 +25,24 @@ ((1,), (5,), (5, 1)), # single voxel ((4, 3, 1), (5, 4, 3), (5, 4, 3, 1)), ], + ids=[ + 'single_map_diff_contrast_times', + 'single_map_diff_contrast_times_2', + 'multiple_maps_additional_batch_dim', + 'multiple_maps_additional_batch_dim_2', + 'multiple_maps_additional_batch_dim_3', + 'multiple_maps_other_dim', + 'multiple_maps_other_dim_2', + 'multiple_maps_other_dim_3', + 'multiple_maps_other_and_batch_dim', + 'multiple_maps_other_and_batch_dim_2', + 'multiple_maps_other_and_batch_dim_3', + 'multiple_maps_other_and_batch_dim_4', + 'multiple_maps_other_and_batch_dim_5', + 'different_value_each_voxel', + 'single_voxel', + 'multiple_voxels', + ], ) diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 7caa13e7c..28c6e8860 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -17,10 +17,10 @@ def test_cart_sampling_op_data_match(): nkx = (1, 1, 1, 60) nky = (1, 1, 40, 1) nkz = (1, 20, 1, 1) - sx = 'uf' - sy = 'uf' - sz = 'uf' - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + type_kx = 'uniform' + type_ky = 'uniform' + type_kz = 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Create matching data random_generator = RandomGenerator(seed=0) @@ -73,10 +73,10 @@ def test_cart_sampling_op_fwd_adj(sampling): nkx = (2, 1, 1, 60) nky = (2, 1, 40, 1) nkz = (2, 20, 1, 1) - sx = 'uf' - sy = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' - sz = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' - trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() + type_kx = 'uniform' + type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz).as_tensor() # Subsample data and trajectory match sampling: diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 89a4bbc11..6f8c377cf 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -9,22 +9,24 @@ from tests.helper import dotproduct_adjointness_test -def create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): random_generator = RandomGenerator(seed=0) # generate random image img = random_generator.complex64_tensor(size=im_shape) # create random trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) return img, trajectory @COMMON_MR_TRAJECTORIES -def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_fourier_op_fwd_adj_property( + im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2 +): """Test adjoint property of Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) @@ -42,26 +44,26 @@ def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, @pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz'), # parameter names [ - # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), # Cartesian ky dimension defined along k2 rather than k1 - (5, 1, 18, 64), - 'nuf', - 'uf', - 'nuf', + ( # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 96, 1, 1), # nky - Cartesian ky dimension defined along k2 rather than k1 + (5, 1, 18, 64), # nkz + 'non-uniform', # type_kx + 'uniform', # type_ky + 'non-uniform', # type_kz ), ], + ids=['cartesian_fft_dims_not_aligned_with_k2_k1_k0_dims'], ) -def test_fourier_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Test trajectory not supported by Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) diff --git a/tests/operators/test_rearrangeop.py b/tests/operators/test_rearrangeop.py index 3db888897..054c99402 100644 --- a/tests/operators/test_rearrangeop.py +++ b/tests/operators/test_rearrangeop.py @@ -15,6 +15,7 @@ ((2, 2, 4), '... a b->... (a b)', (2, 8), {'b': 4}), # flatten ((2), '... (a b) -> ... a b', (2, 1), {'b': 1}), # unflatten ], + ids=['swap_axes', 'flatten', 'unflatten'], ) def test_einsum_op(input_shape, rule, output_shape, additional_info, dtype): """Test adjointness and shape.""" From 72c18070d45bbad969bf2244b51163b10f5b627a Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch <christoph.kolbitsch@ptb.de> Date: Tue, 12 Nov 2024 15:31:57 +0100 Subject: [PATCH 22/23] Add CartesianSamplingOp to FourierOp (#482) Co-authored-by: Felix F Zimmermann <fzimmermann89@gmail.com> --- src/mrpro/operators/FourierOp.py | 41 +++++++++++++++++++----------- tests/conftest.py | 14 ++++++++++ tests/data/test_kdata.py | 13 ---------- tests/operators/test_fourier_op.py | 29 +++++++++++++++++++-- 4 files changed, 67 insertions(+), 30 deletions(-) diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index f4254d8d2..c57d1fbfd 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -11,6 +11,7 @@ from mrpro.data.enums import TrajType from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp from mrpro.operators.FastFourierOp import FastFourierOp from mrpro.operators.LinearOperator import LinearOperator @@ -67,12 +68,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): self._nufft_dims.append(dim) if self._fft_dims: - self._fast_fourier_op = FastFourierOp( + self._fast_fourier_op: FastFourierOp | None = FastFourierOp( dim=tuple(self._fft_dims), recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims), encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims), ) - + self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp( + encoding_matrix=encoding_matrix, traj=traj + ) + else: + self._fast_fourier_op = None + self._cart_sampling_op = None # Find dimensions which require NUFFT if self._nufft_dims: fft_dims_k210 = [ @@ -102,20 +108,23 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction - self._fwd_nufft_op = KbNufft( + self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - self._adj_nufft_op = KbNufftAdjoint( + self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - - self._kshape = traj.broadcasted_shape + else: + self._omega: torch.Tensor | None = None + self._fwd_nufft_op = None + self._adj_nufft_op = None + self._kshape = traj.broadcasted_shape @classmethod def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self: @@ -146,11 +155,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil k-space data with shape: (... coils k2 k1 k0) """ - if len(self._fft_dims): - # FFT - (x,) = self._fast_fourier_op(x) - - if self._nufft_dims: + if self._fwd_nufft_op is not None and self._omega is not None: + # NUFFT Type 2 # we need to move the nufft-dimensions to the end and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils nufft_dims # we could move the permute to __init__ but then we still would need to prepend if len(other)>1 @@ -163,7 +169,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.flatten(end_dim=-len(keep_dims) - 1) # omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims) - # TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape). omega = self._omega.permute(*permute) omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :]) omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) @@ -173,6 +178,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims] x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils x = x.permute(*unpermute) + + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: + # FFT + (x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0]) + return (x,) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: @@ -187,11 +197,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil image data with shape: (... coils z y x) """ - if self._fft_dims: + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: # IFFT - (x,) = self._fast_fourier_op.adjoint(x) + (x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0]) - if self._nufft_dims: + if self._adj_nufft_op is not None and self._omega is not None: + # NUFFT Type 1 # we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils (nufft_dims) keep_dims = [-4, *self._nufft_dims] # -4 is coil diff --git a/tests/conftest.py b/tests/conftest.py index 1fb7fb95f..3bd8946f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from xsdata.models.datatype import XmlDate, XmlTime from tests import RandomGenerator +from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData @@ -249,6 +250,19 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): return trajectory +@pytest.fixture(scope='session') +def ismrmrd_cart(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + ) + return ismrmrd_kdata + + COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index ab4f5aabb..d5cfa0f0c 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -16,19 +16,6 @@ from tests.phantoms import EllipsePhantomTestData -@pytest.fixture(scope='session') -def ismrmrd_cart(ellipse_phantom, tmp_path_factory): - """Fully sampled cartesian data set.""" - ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' - ismrmrd_kdata = IsmrmrdRawTestData( - filename=ismrmrd_filename, - noise_level=0.0, - repetitions=3, - phantom=ellipse_phantom.phantom, - ) - return ismrmrd_kdata - - @pytest.fixture(scope='session') def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): """Fully sampled cartesian data set with bodycoil and surface coil data.""" diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 6f8c377cf..f48a24260 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -1,7 +1,9 @@ """Tests for Fourier operator.""" import pytest -from mrpro.data import SpatialDimension +import torch +from mrpro.data import KData, KTrajectory, SpatialDimension +from mrpro.data.traj_calculators import KTrajectoryCartesian from mrpro.operators import FourierOp from tests import RandomGenerator @@ -30,7 +32,11 @@ def test_fourier_op_fwd_adj_property( # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) # apply forward operator @@ -70,3 +76,22 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + +def test_fourier_op_cartesian_sorting(ismrmrd_cart): + """Verify correct sorting of Cartesian k-space data before FFT.""" + kdata = KData.from_file(ismrmrd_cart.filename, KTrajectoryCartesian()) + ff_op = FourierOp.from_kdata(kdata) + (img,) = ff_op.adjoint(kdata.data) + + # shuffle the kspace points along k0 + permutation_index = torch.randperm(kdata.data.shape[-1]) + kdata_unsorted = KData( + header=kdata.header, + data=kdata.data[..., permutation_index], + traj=KTrajectory.from_tensor(kdata.traj.as_tensor()[..., permutation_index]), + ) + ff_op_unsorted = FourierOp.from_kdata(kdata_unsorted) + (img_unsorted,) = ff_op_unsorted.adjoint(kdata_unsorted.data) + + torch.testing.assert_close(img, img_unsorted) From 202d395a2c10f2c3db7c491851988fd26dca14a5 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch <christoph.kolbitsch@ptb.de> Date: Tue, 12 Nov 2024 17:17:18 +0100 Subject: [PATCH 23/23] Exclude data outside of encoding_matrix (#234) Co-authored-by: Felix Zimmermann <fzimmermann89@gmail.com> --- src/mrpro/operators/CartesianSamplingOp.py | 98 ++++++++++++++++--- tests/conftest.py | 2 +- tests/operators/test_cartesian_sampling_op.py | 25 +++++ tests/operators/test_fourier_op.py | 6 +- 4 files changed, 115 insertions(+), 16 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 7a51924b1..07f8aba65 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -1,5 +1,7 @@ """Cartesian Sampling Operator.""" +import warnings + import torch from einops import rearrange, repeat @@ -7,6 +9,7 @@ from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension from mrpro.operators.LinearOperator import LinearOperator +from mrpro.utils.reshape import unsqueeze_left class CartesianSamplingOp(LinearOperator): @@ -64,10 +67,35 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') + + # check that all points are inside the encoding matrix + inside_encoding_matrix = ( + ((kx_idx >= 0) & (kx_idx < sorted_grid_shape.x)) + & ((ky_idx >= 0) & (ky_idx < sorted_grid_shape.y)) + & ((kz_idx >= 0) & (kz_idx < sorted_grid_shape.z)) + ) + if not torch.all(inside_encoding_matrix): + warnings.warn( + 'K-space points lie outside of the encoding_matrix and will be ignored.' + ' Increase the encoding_matrix to include these points.', + stacklevel=2, + ) + + inside_encoding_matrix = rearrange(inside_encoding_matrix, '... kz ky kx -> ... 1 (kz ky kx)') + inside_encoding_matrix_idx = inside_encoding_matrix.nonzero(as_tuple=True)[-1] + inside_encoding_matrix_idx = torch.reshape(inside_encoding_matrix_idx, (*kidx.shape[:-1], -1)) + self.register_buffer('_inside_encoding_matrix_idx', inside_encoding_matrix_idx) + kidx = torch.take_along_dim(kidx, inside_encoding_matrix_idx, dim=-1) + else: + self._inside_encoding_matrix_idx: torch.Tensor | None = None + self.register_buffer('_fft_idx', kidx) + # we can skip the indexing if the data is already sorted self._needs_indexing = ( - not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + not torch.all(torch.diff(kidx) == 1) + or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + or self._inside_encoding_matrix_idx is not None ) self._trajectory_shape = traj.broadcasted_shape @@ -93,8 +121,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: return (x,) x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)') - # take_along_dim does broadcast, so no need for extending here - x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1) + # take_along_dim broadcasts, but needs the same number of dimensions + idx = unsqueeze_left(self._fft_idx, x_kflat.ndim - self._fft_idx.ndim) + x_inside_encoding_matrix = torch.take_along_dim(x_kflat, idx, dim=-1) + + if self._inside_encoding_matrix_idx is None: + # all trajectory points are inside the encoding matrix + x_indexed = x_inside_encoding_matrix + else: + # we need to add zeros + x_indexed = self._broadcast_and_scatter_along_last_dim( + x_inside_encoding_matrix, + self._trajectory_shape[-1] * self._trajectory_shape[-2] * self._trajectory_shape[-3], + self._inside_encoding_matrix_idx, + ) + # reshape to (... other coil, k2, k1, k0) x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:]) @@ -120,18 +161,13 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)') - # scatter does not broadcast, so we need to manually broadcast the indices - broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1]) - idx_expanded = torch.broadcast_to(self._fft_idx, (*broadcast_shape, self._fft_idx.shape[-1])) + if self._inside_encoding_matrix_idx is not None: + idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim) + y_kflat = torch.take_along_dim(y_kflat, idx, dim=-1) - # although scatter_ is inplace, this will not cause issues with autograd, as self - # is always constant zero and gradients w.r.t. src work as expected. - y_scattered = torch.zeros( - *broadcast_shape, - self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, - dtype=y.dtype, - device=y.device, - ).scatter_(dim=-1, index=idx_expanded, src=y_kflat) + y_scattered = self._broadcast_and_scatter_along_last_dim( + y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx + ) # reshape to ..., other, coil, k2_enc, k1_enc, k0_enc y_reshaped = y_scattered.reshape( @@ -142,3 +178,37 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: ) return (y_reshaped,) + + @staticmethod + def _broadcast_and_scatter_along_last_dim( + data_to_scatter: torch.Tensor, n_last_dim: int, scatter_index: torch.Tensor + ) -> torch.Tensor: + """Broadcast scatter index and scatter into zero tensor. + + Parameters + ---------- + data_to_scatter + Data to be scattered at indices scatter_index + n_last_dim + Number of data points in last dimension + scatter_index + Indices describing where to scatter data + + Returns + ------- + Data scattered into tensor along scatter_index + """ + # scatter does not broadcast, so we need to manually broadcast the indices + broadcast_shape = torch.broadcast_shapes(scatter_index.shape[:-1], data_to_scatter.shape[:-1]) + idx_expanded = torch.broadcast_to(scatter_index, (*broadcast_shape, scatter_index.shape[-1])) + + # although scatter_ is inplace, this will not cause issues with autograd, as self + # is always constant zero and gradients w.r.t. src work as expected. + data_scattered = torch.zeros( + *broadcast_shape, + n_last_dim, + dtype=data_to_scatter.dtype, + device=data_to_scatter.device, + ).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter) + + return data_scattered diff --git a/tests/conftest.py b/tests/conftest.py index 3bd8946f2..30ae9c229 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -240,7 +240,7 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): k_list = [] for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True): if spacing == 'non-uniform': - k = random_generator.float32_tensor(size=nk) + k = random_generator.float32_tensor(size=nk, low=-1, high=1) * max(nk) elif spacing == 'uniform': k = create_uniform_traj(nk, k_shape=k_shape) elif spacing == 'zero': diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 28c6e8860..c1738b7bb 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -118,3 +118,28 @@ def test_cart_sampling_op_fwd_adj(sampling): u = random_generator.complex64_tensor(size=k_shape) v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:]) dotproduct_adjointness_test(sampling_op, u, v) + + +@pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)]) +@pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)]) +def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max): + """Test trajectory points outside of encoding_matrix.""" + encoding_matrix = SpatialDimension(40, 1, 20) + + # Create kx and kz sampling which are asymmetric and larger than the encoding matrix on one side + # The indices are inverted to ensure CartesianSamplingOp acts on them + kx = rearrange(torch.linspace(k0_max, k0_min, 20), 'kx->1 1 1 kx') + ky = torch.ones(1, 1, 1, 1) + kz = rearrange(torch.linspace(k2_max, k2_min, 40), 'kz-> kz 1 1') + kz = torch.stack([kz, -kz], dim=0) # different kz values for two other elements + trajectory = KTrajectory(kz=kz, ky=ky, kx=kx) + + with pytest.warns(UserWarning, match='K-space points lie outside of the encoding_matrix'): + sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) + + random_generator = RandomGenerator(seed=0) + u = random_generator.complex64_tensor(size=(3, 2, 5, kz.shape[-3], ky.shape[-2], kx.shape[-1])) + v = random_generator.complex64_tensor(size=(3, 2, 5, *encoding_matrix.zyx)) + + assert sampling_op.adjoint(u)[0].shape[-3:] == encoding_matrix.zyx + assert sampling_op(v)[0].shape[-3:] == (kz.shape[-3], ky.shape[-2], kx.shape[-1]) diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index f48a24260..2d76642c3 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -73,7 +73,11 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory)