From 9c28a48aeed9dafe6a249a5ddd1d5b62581e9ac0 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:41:06 +0200 Subject: [PATCH 01/10] Fix PyTest and Deployment workflow triggers (#467) --- .github/workflows/deployment.yml | 49 +++++++++++++++++++------------- .github/workflows/pytest.yml | 48 ++++++++++++++++--------------- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index 9f02e8dcc..dd900496e 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -1,4 +1,4 @@ -name: Publish to PyPI +name: Build and Deploy Package on: push: @@ -7,21 +7,31 @@ on: jobs: build-testpypi-package: - name: Modify version and build package for TestPyPI + name: Build Package for TestPyPI runs-on: ubuntu-latest outputs: version: ${{ steps.set_suffix.outputs.version }} suffix: ${{ steps.set_suffix.outputs.suffix }} + version_changed: ${{ steps.changes.outputs.version_changed }} steps: - - name: Checkout code - uses: actions/checkout@v4 - - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.x" - - name: Set environment variable for version suffix + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Check if VERSION file is modified compared to main + uses: dorny/paths-filter@v3 + id: changes + with: + base: main + filters: | + version_changed: + - 'src/mrpro/VERSION' + + - name: Set Version Suffix id: set_suffix run: | VERSION=$(cat src/mrpro/VERSION) @@ -30,12 +40,12 @@ jobs: echo "suffix=$SUFFIX" >> $GITHUB_OUTPUT echo "version=$VERSION" >> $GITHUB_OUTPUT - - name: Build package with version suffix + - name: Build Package run: | python -m pip install --upgrade build python -m build - - name: Store TestPyPi distribution + - name: Upload TestPyPI Distribution Artifact uses: actions/upload-artifact@v4 with: name: testpypi-package-distribution @@ -46,6 +56,7 @@ jobs: needs: - build-testpypi-package runs-on: ubuntu-latest + if: needs.build-testpypi-package.outputs.version_changed == 'true' environment: name: testpypi @@ -55,7 +66,7 @@ jobs: id-token: write steps: - - name: Download TestPyPi distribution + - name: Download TestPyPI Distribution uses: actions/download-artifact@v4 with: name: testpypi-package-distribution @@ -68,7 +79,7 @@ jobs: verbose: true test-install-from-testpypi: - name: Test installation from TestPyPI + name: Test Installation from TestPyPI needs: - testpypi-deployment - build-testpypi-package @@ -86,14 +97,14 @@ jobs: python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ build-pypi-package: - name: Build package for PyPI + name: Build Package for PyPI runs-on: ubuntu-latest needs: - test-install-from-testpypi outputs: version: ${{ steps.get_version.outputs.version }} steps: - - name: Checkout code + - name: Checkout Code uses: actions/checkout@v4 - name: Set up Python @@ -101,22 +112,22 @@ jobs: with: python-version: "3.x" - - name: Install setuptools_git_versioning + - name: Install Automatic Versioning Tool run: | python -m pip install setuptools-git-versioning - - name: Get current version + - name: Get Current Version id: get_version run: | VERSION=$(python -m setuptools_git_versioning) echo "VERSION=$VERSION" >> $GITHUB_OUTPUT - - name: Build package + - name: Build Package run: | python -m pip install --upgrade build python -m build - - name: Store PyPi distribution + - name: Store PyPI Distribution uses: actions/upload-artifact@v4 with: name: pypi-package-distribution @@ -138,13 +149,13 @@ jobs: id-token: write steps: - - name: Download PyPi distribution + - name: Download PyPI Distribution uses: actions/download-artifact@v4 with: name: pypi-package-distribution path: dist/ - - name: Create tag + - name: Create Tag uses: actions/github-script@v7 with: script: | @@ -155,7 +166,7 @@ jobs: sha: context.sha }) - - name: Create release + - name: Create Release uses: actions/github-script@v7 with: github-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 427b62e96..53d1343c7 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -2,17 +2,21 @@ name: PyTest on: pull_request: + push: + branches: + - main jobs: get_dockerfiles: - name: Get list of dockerfiles for different containers + name: Get List of Dockerfiles for Containers runs-on: ubuntu-latest permissions: packages: read outputs: imagenames: ${{ steps.set-matrix.outputs.imagenames }} steps: - - id: set-matrix + - name: Retrieve Docker Image Names + id: set-matrix env: GH_TOKEN: ${{ secrets.GHCR_TOKEN }} run: | @@ -35,58 +39,56 @@ jobs: echo "image names with tag latest: $imagenames_latest" echo "imagenames=$imagenames_latest" >> $GITHUB_OUTPUT - - name: Dockerfile overview + - name: Dockerfile Overview run: | - echo "final list of images with tag latest: ${{ steps.set-matrix.outputs.imagenames }}" + echo "Final list of images with tag latest: ${{ steps.set-matrix.outputs.imagenames }}" test: - name: Run tests and get coverage report + name: Run Tests and Coverage Report needs: get_dockerfiles runs-on: ubuntu-latest permissions: pull-requests: write contents: write strategy: - fail-fast: false + fail-fast: false matrix: imagename: ${{ fromJson(needs.get_dockerfiles.outputs.imagenames) }} - # runs within Docker container container: image: ghcr.io/ptb-mr/${{ matrix.imagename }}:latest options: --user runner - steps: - - name: Checkout repo + - name: Checkout Repository uses: actions/checkout@v4 - - name: Install mrpro and dependencies - run: pip install --upgrade --upgrade-strategy "eager" .[test] + - name: Install MRpro and Dependencies + run: pip install --upgrade --upgrade-strategy eager .[test] - - name: Install pytest-github-actions-annotate-failures plugin + - name: Install PyTest GitHub Annotation Plugin run: pip install pytest-github-actions-annotate-failures - - name: Run PyTest + - name: Run PyTest and Generate Coverage Report run: | - pytest -n 4 -m "not cuda" --junitxml=pytest.xml --cov-report=term-missing:skip-covered --cov=mrpro | tee pytest-coverage.txt + pytest -n 4 -m "not cuda" --junitxml=pytest.xml \ + --cov-report=term-missing:skip-covered --cov=mrpro | tee pytest-coverage.txt - - name: Check for pytest.xml + - name: Verify PyTest XML Output run: | - if [ -f pytest.xml ]; then - echo "pytest.xml file found. Continuing..." - else - echo "pytest.xml file not found. Please check previous 'Run PyTest' section for errors." + if [ ! -f pytest.xml ]; then + echo "PyTest XML report not found. Please check the previous 'Run PyTest' step for errors." exit 1 fi - - name: Pytest coverage comment + - name: Post PyTest Coverage Comment id: coverageComment uses: MishaKav/pytest-coverage-comment@v1.1.53 with: pytest-coverage-path: ./pytest-coverage.txt junitxml-path: ./pytest.xml - - name: Create the Badge + - name: Create Coverage Badge on Main Branch Push uses: schneegans/dynamic-badges-action@v1.7.0 + if: github.event_name == 'push' && github.ref == 'refs/heads/main' with: auth: ${{ secrets.GIST_SECRET }} gistID: 48e334a10caf60e6708d7c712e56d241 @@ -96,12 +98,12 @@ jobs: color: ${{ steps.coverageComment.outputs.color }} namedLogo: python - - name: Set pipeline status + - name: Set Pipeline Status Based on Test Results if: steps.coverageComment.outputs.errors != 0 || steps.coverageComment.outputs.failures != 0 uses: actions/github-script@v7 with: script: | - core.setFailed('PyTest workflow failed with ${{ steps.coverageComment.outputs.errors }} errors and ${{ steps.coverageComment.outputs.failures }} failures.') + core.setFailed("PyTest workflow failed with ${{ steps.coverageComment.outputs.errors }} errors and ${{ steps.coverageComment.outputs.failures }} failures.") concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} From ec503af94aef49298b194d9a14ff476cf880930b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Thu, 24 Oct 2024 10:44:00 +0200 Subject: [PATCH 02/10] Remove superfluous .forward (#472) --- examples/pulseq_2d_radial_golden_angle.ipynb | 6 +++--- examples/pulseq_2d_radial_golden_angle.py | 6 +++--- src/mrpro/operators/FourierOp.py | 2 +- tests/operators/models/test_inversion_recovery.py | 4 ++-- tests/operators/models/test_molli.py | 4 ++-- tests/operators/models/test_mono_exponential_decay.py | 4 ++-- tests/operators/models/test_saturation_recovery.py | 4 ++-- .../test_transient_steady_state_with_preparation.py | 6 +++--- tests/operators/models/test_wasabi.py | 8 ++++---- tests/operators/models/test_wasabiti.py | 10 +++++----- tests/operators/test_sensitivity_op.py | 2 +- tests/operators/test_zero_pad_op.py | 2 +- 12 files changed, 29 insertions(+), 29 deletions(-) diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index d0e6e0adb..52e0310bb 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -76,7 +76,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_ismrmrd_traj = direct_reconstruction.forward(kdata)" + "img_using_ismrmrd_traj = direct_reconstruction(kdata)" ] }, { @@ -100,7 +100,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_rad2d_traj = direct_reconstruction.forward(kdata)" + "img_using_rad2d_traj = direct_reconstruction(kdata)" ] }, { @@ -143,7 +143,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_pulseq_traj = direct_reconstruction.forward(kdata)" + "img_using_pulseq_traj = direct_reconstruction(kdata)" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index 955b20a37..3f857c382 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -37,7 +37,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_ismrmrd_traj = direct_reconstruction.forward(kdata) +img_using_ismrmrd_traj = direct_reconstruction(kdata) # %% [markdown] # ### Image reconstruction using KTrajectoryRadial2D @@ -49,7 +49,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_rad2d_traj = direct_reconstruction.forward(kdata) +img_using_rad2d_traj = direct_reconstruction(kdata) # %% [markdown] # ### Image reconstruction using KTrajectoryPulseq @@ -73,7 +73,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_pulseq_traj = direct_reconstruction.forward(kdata) +img_using_pulseq_traj = direct_reconstruction(kdata) # %% [markdown] # ### Plot the different reconstructed images diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index 79351da76..f4254d8d2 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -148,7 +148,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: """ if len(self._fft_dims): # FFT - (x,) = self._fast_fourier_op.forward(x) + (x,) = self._fast_fourier_op(x) if self._nufft_dims: # we need to move the nufft-dimensions to the end and flatten all other dimensions diff --git a/tests/operators/models/test_inversion_recovery.py b/tests/operators/models/test_inversion_recovery.py index 71bb4dfe6..52f957f8f 100644 --- a/tests/operators/models/test_inversion_recovery.py +++ b/tests/operators/models/test_inversion_recovery.py @@ -21,7 +21,7 @@ def test_inversion_recovery(ti, result): """ model = InversionRecovery(ti) m0, t1 = create_parameter_tensor_tuples() - (image,) = model.forward(m0, t1) + (image,) = model(m0, t1) # Assert closeness to -m0 for ti=0 if result == '-m0': @@ -37,5 +37,5 @@ def test_inversion_recovery_shape(parameter_shape, contrast_dim_shape, signal_sh (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = InversionRecovery(ti) m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, t1) + (signal,) = model_op(m0, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_molli.py b/tests/operators/models/test_molli.py index 8ee5f9117..c92f92556 100644 --- a/tests/operators/models/test_molli.py +++ b/tests/operators/models/test_molli.py @@ -27,7 +27,7 @@ def test_molli(ti, result): # Generate signal model and torch tensor for comparison model = MOLLI(ti) - (image,) = model.forward(a, c, t1) + (image,) = model(a, c, t1) # Assert closeness to a(1-c) for large ti if result == 'a(1-c)': @@ -43,5 +43,5 @@ def test_molli_shape(parameter_shape, contrast_dim_shape, signal_shape): (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = MOLLI(ti) a, c, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(a, c, t1) + (signal,) = model_op(a, c, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_mono_exponential_decay.py b/tests/operators/models/test_mono_exponential_decay.py index c56e86e12..d77d5862c 100644 --- a/tests/operators/models/test_mono_exponential_decay.py +++ b/tests/operators/models/test_mono_exponential_decay.py @@ -21,7 +21,7 @@ def test_mono_exponential_decay(decay_time, result): """ model = MonoExponentialDecay(decay_time) m0, decay_constant = create_parameter_tensor_tuples() - (image,) = model.forward(m0, decay_constant) + (image,) = model(m0, decay_constant) zeros = torch.zeros_like(m0) @@ -39,5 +39,5 @@ def test_mono_exponential_decay_shape(parameter_shape, contrast_dim_shape, signa (decay_time,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = MonoExponentialDecay(decay_time) m0, decay_constant = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, decay_constant) + (signal,) = model_op(m0, decay_constant) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_saturation_recovery.py b/tests/operators/models/test_saturation_recovery.py index 1a821282b..692d4cc31 100644 --- a/tests/operators/models/test_saturation_recovery.py +++ b/tests/operators/models/test_saturation_recovery.py @@ -21,7 +21,7 @@ def test_saturation_recovery(ti, result): """ model = SaturationRecovery(ti) m0, t1 = create_parameter_tensor_tuples() - (image,) = model.forward(m0, t1) + (image,) = model(m0, t1) zeros = torch.zeros_like(m0) @@ -39,5 +39,5 @@ def test_saturation_recovery_shape(parameter_shape, contrast_dim_shape, signal_s (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = SaturationRecovery(ti) m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, t1) + (signal,) = model_op(m0, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_transient_steady_state_with_preparation.py b/tests/operators/models/test_transient_steady_state_with_preparation.py index 87d21a336..2b9d2e451 100644 --- a/tests/operators/models/test_transient_steady_state_with_preparation.py +++ b/tests/operators/models/test_transient_steady_state_with_preparation.py @@ -23,7 +23,7 @@ def test_transient_steady_state(sampling_time, m0_scaling_preparation, result): repetition_time = 5 model = TransientSteadyStateWithPreparation(sampling_time, repetition_time, m0_scaling_preparation) m0, t1, flip_angle = create_parameter_tensor_tuples(number_of_tensors=3) - (signal,) = model.forward(m0, t1, flip_angle) + (signal,) = model(m0, t1, flip_angle) # Assert closeness to m0 if result == 'm0': @@ -56,7 +56,7 @@ def test_transient_steady_state_inversion_recovery(): analytical_signal = m0 * (1 - 2 * torch.exp(-(sampling_time / t1))) model = TransientSteadyStateWithPreparation(sampling_time, repetition_time=100, m0_scaling_preparation=-1) - (signal,) = model.forward(m0, t1, flip_angle) + (signal,) = model(m0, t1, flip_angle) torch.testing.assert_close(signal, analytical_signal) @@ -77,5 +77,5 @@ def test_transient_steady_state_shape(parameter_shape, contrast_dim_shape, signa sampling_time, repetition_time, m0_scaling_preparation, delay_after_preparation ) m0, t1, flip_angle = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(m0, t1, flip_angle) + (signal,) = model_op(m0, t1, flip_angle) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_wasabi.py b/tests/operators/models/test_wasabi.py index 1aa36384e..6b4494539 100644 --- a/tests/operators/models/test_wasabi.py +++ b/tests/operators/models/test_wasabi.py @@ -14,11 +14,11 @@ def test_WASABI_shift(): """Test symmetry property of shifted and unshifted WASABI spectra.""" offsets_unshifted, b0_shift, rb1, c, d = create_data() wasabi_model = WASABI(offsets=offsets_unshifted) - (signal,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal,) = wasabi_model(b0_shift, rb1, c, d) offsets_shifted, b0_shift, rb1, c, d = create_data(b0_shift=100) wasabi_model = WASABI(offsets=offsets_shifted) - (signal_shifted,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal_shifted,) = wasabi_model(b0_shift, rb1, c, d) lower_index = (offsets_shifted == -300).nonzero()[0][0].item() upper_index = (offsets_shifted == 500).nonzero()[0][0].item() @@ -30,7 +30,7 @@ def test_WASABI_shift(): def test_WASABI_extreme_offset(): offset, b0_shift, rb1, c, d = create_data(offset_max=30000, n_offsets=1) wasabi_model = WASABI(offsets=offset) - (signal,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal,) = wasabi_model(b0_shift, rb1, c, d) assert torch.isclose(signal, torch.tensor([1.0])), 'For an extreme offset, the signal should be unattenuated' @@ -41,5 +41,5 @@ def test_WASABI_shape(parameter_shape, contrast_dim_shape, signal_shape): (offsets,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = WASABI(offsets) b0_shift, rb1, c, d = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=4) - (signal,) = model_op.forward(b0_shift, rb1, c, d) + (signal,) = model_op(b0_shift, rb1, c, d) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_wasabiti.py b/tests/operators/models/test_wasabiti.py index bd451a03b..de2cacc50 100644 --- a/tests/operators/models/test_wasabiti.py +++ b/tests/operators/models/test_wasabiti.py @@ -15,7 +15,7 @@ def test_WASABITI_symmetry(): """Test symmetry property of complete WASABITI spectra.""" offsets, b0_shift, rb1, t1 = create_data() wasabiti_model = WASABITI(offsets=offsets, trec=torch.ones_like(offsets)) - (signal,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal,) = wasabiti_model(b0_shift, rb1, t1) # check that all values are symmetric around the center assert torch.allclose(signal, signal.flipud(), rtol=1e-15), 'Result should be symmetric around center' @@ -26,7 +26,7 @@ def test_WASABITI_symmetry_after_shift(): offsets_shifted, b0_shift, rb1, t1 = create_data(b0_shift=100) trec = torch.ones_like(offsets_shifted) wasabiti_model = WASABITI(offsets=offsets_shifted, trec=trec) - (signal_shifted,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal_shifted,) = wasabiti_model(b0_shift, rb1, t1) lower_index = (offsets_shifted == -300).nonzero()[0][0].item() upper_index = (offsets_shifted == 500).nonzero()[0][0].item() @@ -42,7 +42,7 @@ def test_WASABITI_asymmetry_for_non_unique_trec(): trec[: len(offsets_unshifted) // 2] = 2.0 wasabiti_model = WASABITI(offsets=offsets_unshifted, trec=trec) - (signal,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal,) = wasabiti_model(b0_shift, rb1, t1) assert not torch.allclose(signal, signal.flipud(), rtol=1e-8), 'Result should not be symmetric around center' @@ -53,7 +53,7 @@ def test_WASABITI_relaxation_term(t1): offset, b0_shift, rb1, t1 = create_data(offset_max=50000, n_offsets=1, t1=t1) trec = torch.ones_like(offset) * t1 wasabiti_model = WASABITI(offsets=offset, trec=trec) - sig = wasabiti_model.forward(b0_shift, rb1, t1) + sig = wasabiti_model(b0_shift, rb1, t1) assert torch.isclose(sig[0], torch.FloatTensor([1 - torch.exp(torch.FloatTensor([-1]))]), rtol=1e-8) @@ -72,5 +72,5 @@ def test_WASABITI_shape(parameter_shape, contrast_dim_shape, signal_shape): ti, trec = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=2) model_op = WASABITI(ti, trec) b0_shift, rb1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(b0_shift, rb1, t1) + (signal,) = model_op(b0_shift, rb1, t1) assert signal.shape == signal_shape diff --git a/tests/operators/test_sensitivity_op.py b/tests/operators/test_sensitivity_op.py index 25efc5bee..1321498d8 100644 --- a/tests/operators/test_sensitivity_op.py +++ b/tests/operators/test_sensitivity_op.py @@ -90,7 +90,7 @@ def test_sensitivity_op_other_dim_compatibility_fail(n_other_csm, n_other_img): # Apply to n_other_img shape u = random_generator.complex64_tensor(size=(n_other_img, 1, *n_zyx)) with pytest.raises(RuntimeError, match='The size of tensor'): - sensitivity_op.forward(u) + sensitivity_op(u) v = random_generator.complex64_tensor(size=(n_other_img, n_coils, *n_zyx)) with pytest.raises(RuntimeError, match='The size of tensor'): diff --git a/tests/operators/test_zero_pad_op.py b/tests/operators/test_zero_pad_op.py index a56635f5f..ce2d8855d 100644 --- a/tests/operators/test_zero_pad_op.py +++ b/tests/operators/test_zero_pad_op.py @@ -20,7 +20,7 @@ def test_zero_pad_op_content(): original_shape=tuple([original_shape[d] for d in padding_dimensions]), padded_shape=tuple([padded_shape[d] for d in padding_dimensions]), ) - (padded_data,) = zero_padding_op.forward(original_data) + (padded_data,) = zero_padding_op(original_data) # Compare overlapping region torch.testing.assert_close(original_data[:, 10:90, :, 50:150, :, :], padded_data[:, :, :, :, 95:145, :]) From eecdf490f6bb11c147de8e680f8581e23156fd47 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Thu, 24 Oct 2024 13:13:32 +0200 Subject: [PATCH 03/10] Remove superfluous .forward (#475) Co-authored-by: Felix Zimmermann --- src/mrpro/operators/FastFourierOp.py | 2 +- tests/operators/test_fast_fourier_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/FastFourierOp.py b/src/mrpro/operators/FastFourierOp.py index 53f3f6eb4..4ffe7e3a6 100644 --- a/src/mrpro/operators/FastFourierOp.py +++ b/src/mrpro/operators/FastFourierOp.py @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: FFT of x """ y = torch.fft.fftshift( - torch.fft.fftn(torch.fft.ifftshift(*self._pad_op.forward(x), dim=self._dim), dim=self._dim, norm='ortho'), + torch.fft.fftn(torch.fft.ifftshift(*self._pad_op(x), dim=self._dim), dim=self._dim, norm='ortho'), dim=self._dim, ) return (y,) diff --git a/tests/operators/test_fast_fourier_op.py b/tests/operators/test_fast_fourier_op.py index 7ac1a8211..7e2f47fd7 100644 --- a/tests/operators/test_fast_fourier_op.py +++ b/tests/operators/test_fast_fourier_op.py @@ -28,7 +28,7 @@ def test_fast_fourier_op_forward(npoints, a): # Transform image to k-space ff_op = FastFourierOp(dim=(0,)) - (igauss_fwd,) = ff_op.forward(igauss) + (igauss_fwd,) = ff_op(igauss) # Scaling to "undo" fft scaling igauss_fwd *= np.sqrt(npoints) / 2 From bbdc97434828780d6ae7016d27353361e4d35ad1 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Sat, 26 Oct 2024 19:53:53 +0200 Subject: [PATCH 04/10] Add RegularizedIterativeSENSEReconstruction (#388) Co-authored-by: NAME --- ...rized_iterative_sense_reconstruction.ipynb | 389 ++++++++++++++++++ ...ularized_iterative_sense_reconstruction.py | 193 +++++++++ .../IterativeSENSEReconstruction.py | 87 +--- .../reconstruction/Reconstruction.py | 4 +- ...RegularizedIterativeSENSEReconstruction.py | 139 +++++++ .../algorithms/reconstruction/__init__.py | 6 +- 6 files changed, 733 insertions(+), 85 deletions(-) create mode 100644 examples/regularized_iterative_sense_reconstruction.ipynb create mode 100644 examples/regularized_iterative_sense_reconstruction.py create mode 100644 src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb new file mode 100644 index 000000000..6b1c2704b --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "af432293", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Regularized Iterative SENSE Reconstruction of 2D golden angle radial data\n", + "Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a7a6ce3", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# define zenodo URL of the example ismrmd data\n", + "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cd8486b", + "metadata": {}, + "outputs": [], + "source": [ + "# Download raw data\n", + "import tempfile\n", + "\n", + "import requests\n", + "\n", + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" + ] + }, + { + "cell_type": "markdown", + "id": "6a9defa1", + "metadata": {}, + "source": [ + "### Image reconstruction\n", + "We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data.\n", + "RegularizedIterativeSENSEReconstruction solves the following reconstruction problem:\n", + "\n", + "Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms,\n", + "coil sensitivity maps...) $A$ then we can formulate the forward problem as:\n", + "\n", + "$ y = Ax + n $\n", + "\n", + "where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 $\n", + "\n", + "where $W^\\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal\n", + "operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain\n", + "a solution with certain properties:\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$\n", + "\n", + "where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image.\n", + "With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$.\n", + "\n", + "Setting the derivative of the functional $F$ to zero and rearranging yields\n", + "\n", + "$ (A^H W A + l B) x = A^H W y + l x_{reg}$\n", + "\n", + "which is a linear system $Hx = b$ that needs to be solved for $x$.\n", + "\n", + "One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution\n", + "dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks\n", + "has been used as an image regulariser.\n", + "\n", + "In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image.\n", + "Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using\n", + "only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the\n", + "regularization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4da15c2", + "metadata": {}, + "outputs": [], + "source": [ + "import mrpro" + ] + }, + { + "cell_type": "markdown", + "id": "de055070", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Read-in the raw data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ac1d89f", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.data import KData\n", + "from mrpro.data.traj_calculators import KTrajectoryIsmrmrd\n", + "\n", + "# Load in the Data and the trajectory from the ISMRMRD file\n", + "kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd())\n", + "kdata.header.recon_matrix.x = 256\n", + "kdata.header.recon_matrix.y = 256" + ] + }, + { + "cell_type": "markdown", + "id": "1f389140", + "metadata": {}, + "source": [ + "##### Image $x_{reg}$ from fully sampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "212b915c", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction\n", + "from mrpro.data import CsmData\n", + "\n", + "# Estimate coil maps\n", + "direct_reconstruction = DirectReconstruction(kdata, csm=None)\n", + "img_coilwise = direct_reconstruction(kdata)\n", + "csm = CsmData.from_idata_walsh(img_coilwise)\n", + "\n", + "# Iterative SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3)\n", + "img_iterative_sense = iterative_sense_reconstruction(kdata)" + ] + }, + { + "cell_type": "markdown", + "id": "bec6b712", + "metadata": {}, + "source": [ + "##### Image $x$ from undersampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6740447", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Data undersampling, i.e. take only the first 20 radial lines\n", + "idx_us = torch.arange(0, 20)[None, :]\n", + "kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fbfd664", + "metadata": {}, + "outputs": [], + "source": [ + "# Iterativ SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6)\n", + "img_us_iterative_sense = iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "041ffe72", + "metadata": {}, + "outputs": [], + "source": [ + "# Regularized iterativ SENSE reconstruction\n", + "from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction\n", + "\n", + "regularization_weight = 1.0\n", + "n_iterations = 6\n", + "regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction(\n", + " kdata_us,\n", + " csm=csm,\n", + " n_iterations=n_iterations,\n", + " regularization_data=img_iterative_sense.data,\n", + " regularization_weight=regularization_weight,\n", + ")\n", + "img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d5bbec1", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()]\n", + "vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4))\n", + "for ind in range(3):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "2bd49c87", + "metadata": {}, + "source": [ + "### Behind the scenes" + ] + }, + { + "cell_type": "markdown", + "id": "53779251", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Set-up the density compensation operator $W$ and acquisition model $A$\n", + "\n", + "This is very similar to the iterative SENSE reconstruction. For more detail please look at the\n", + "iterative_sense_reconstruction notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e985a4f3", + "metadata": {}, + "outputs": [], + "source": [ + "dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator()\n", + "fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us)\n", + "csm_operator = csm.as_operator()\n", + "acquisition_operator = fourier_operator @ csm_operator" + ] + }, + { + "cell_type": "markdown", + "id": "2daa0fee", + "metadata": {}, + "source": [ + "##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac1d5fb4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "right_hand_side = (\n", + " acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "76a0b153", + "metadata": {}, + "source": [ + "##### Set-up the linear self-adjoint operator $H = A^H W A + l$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5effb592", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.operators import IdentityOp\n", + "\n", + "operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor(\n", + " regularization_weight\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f24a8588", + "metadata": {}, + "source": [ + "##### Run conjugate gradient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96827838", + "metadata": {}, + "outputs": [], + "source": [ + "img_manual = mrpro.algorithms.optimizers.cg(\n", + " operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18c065a7", + "metadata": {}, + "outputs": [], + "source": [ + "# Display the reconstructed image\n", + "vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]]\n", + "vis_title = ['Regularized Iterative SENSE R=20', '\"Manual\" Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4))\n", + "for ind in range(2):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "d6d7efdf", + "metadata": {}, + "source": [ + "### Check for equal results\n", + "The two versions should result in the same image data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f59b6015", + "metadata": {}, + "outputs": [], + "source": [ + "# If the assert statement did not raise an exception, the results are equal.\n", + "assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual)" + ] + }, + { + "cell_type": "markdown", + "id": "6ecd6e70", + "metadata": {}, + "source": [ + "### Next steps\n", + "Play around with the regularization_weight to see how it effects the final image quality.\n", + "\n", + "Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications\n", + "we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the\n", + "streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality\n", + "compared to the unregularised images." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py new file mode 100644 index 000000000..e41dc4ac5 --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -0,0 +1,193 @@ +# %% [markdown] +# # Regularized Iterative SENSE Reconstruction of 2D golden angle radial data +# Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data +# %% +# define zenodo URL of the example ismrmd data +zenodo_url = 'https://zenodo.org/records/10854057/files/' +fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' +# %% +# Download raw data +import tempfile + +import requests + +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() + +# %% [markdown] +# ### Image reconstruction +# We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data. +# RegularizedIterativeSENSEReconstruction solves the following reconstruction problem: +# +# Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms, +# coil sensitivity maps...) $A$ then we can formulate the forward problem as: +# +# $ y = Ax + n $ +# +# where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$ +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 $ +# +# where $W^\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal +# operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain +# a solution with certain properties: +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$ +# +# where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image. +# With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$. +# +# Setting the derivative of the functional $F$ to zero and rearranging yields +# +# $ (A^H W A + l B) x = A^H W y + l x_{reg}$ +# +# which is a linear system $Hx = b$ that needs to be solved for $x$. +# +# One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution +# dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks +# has been used as an image regulariser. +# +# In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image. +# Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using +# only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the +# regularization. + +# %% +import mrpro + +# %% [markdown] +# ##### Read-in the raw data +# %% +from mrpro.data import KData +from mrpro.data.traj_calculators import KTrajectoryIsmrmrd + +# Load in the Data and the trajectory from the ISMRMRD file +kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd()) +kdata.header.recon_matrix.x = 256 +kdata.header.recon_matrix.y = 256 + +# %% [markdown] +# ##### Image $x_{reg}$ from fully sampled data + +# %% +from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction +from mrpro.data import CsmData + +# Estimate coil maps +direct_reconstruction = DirectReconstruction(kdata, csm=None) +img_coilwise = direct_reconstruction(kdata) +csm = CsmData.from_idata_walsh(img_coilwise) + +# Iterative SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3) +img_iterative_sense = iterative_sense_reconstruction(kdata) + +# %% [markdown] +# ##### Image $x$ from undersampled data + +# %% +import torch + +# Data undersampling, i.e. take only the first 20 radial lines +idx_us = torch.arange(0, 20)[None, :] +kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition') + +# %% +# Iterativ SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6) +img_us_iterative_sense = iterative_sense_reconstruction(kdata_us) + +# %% +# Regularized iterativ SENSE reconstruction +from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction + +regularization_weight = 1.0 +n_iterations = 6 +regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction( + kdata_us, + csm=csm, + n_iterations=n_iterations, + regularization_data=img_iterative_sense.data, + regularization_weight=regularization_weight, +) +img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us) + +# %% +import matplotlib.pyplot as plt + +vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()] +vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4)) +for ind in range(3): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + + +# %% [markdown] +# ### Behind the scenes + +# %% [markdown] +# ##### Set-up the density compensation operator $W$ and acquisition model $A$ +# +# This is very similar to the iterative SENSE reconstruction. For more detail please look at the +# iterative_sense_reconstruction notebook. +# %% +dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator() +fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us) +csm_operator = csm.as_operator() +acquisition_operator = fourier_operator @ csm_operator + +# %% [markdown] +# ##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$ + +# %% +right_hand_side = ( + acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data +) + + +# %% [markdown] +# ##### Set-up the linear self-adjoint operator $H = A^H W A + l$ + +# %% +from mrpro.operators import IdentityOp + +operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor( + regularization_weight +) + +# %% [markdown] +# ##### Run conjugate gradient + +# %% +img_manual = mrpro.algorithms.optimizers.cg( + operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0 +) + +# %% +# Display the reconstructed image +vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]] +vis_title = ['Regularized Iterative SENSE R=20', '"Manual" Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4)) +for ind in range(2): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + +# %% [markdown] +# ### Check for equal results +# The two versions should result in the same image data. + +# %% +# If the assert statement did not raise an exception, the results are equal. +assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual) + +# %% [markdown] +# ### Next steps +# Play around with the regularization_weight to see how it effects the final image quality. +# +# Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications +# we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the +# streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality +# compared to the unregularised images. diff --git a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py index 04d37128c..32785a91a 100644 --- a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py +++ b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py @@ -4,20 +4,17 @@ from collections.abc import Callable -import torch - -from mrpro.algorithms.optimizers.cg import cg -from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import ( + RegularizedIterativeSENSEReconstruction, +) from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData -from mrpro.data.IData import IData from mrpro.data.KNoise import KNoise from mrpro.operators.LinearOperator import LinearOperator -class IterativeSENSEReconstruction(DirectReconstruction): +class IterativeSENSEReconstruction(RegularizedIterativeSENSEReconstruction): r"""Iterative SENSE reconstruction. This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2` @@ -49,6 +46,8 @@ def __init__( ) -> None: """Initialize IterativeSENSEReconstruction. + For a regularized version of the iterative SENSE algorithm please see RegularizedIterativeSENSEReconstruction. + Parameters ---------- kdata @@ -74,76 +73,4 @@ def __init__( ValueError If the kdata and fourier_op are None or if csm is a Callable but kdata is None. """ - super().__init__(kdata, fourier_op, csm, noise, dcf) - self.n_iterations = n_iterations - - def _self_adjoint_operator(self) -> LinearOperator: - """Create the self-adjoint operator. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Create the self-adjoint operator as :math:`H = A^H W A` if the DCF is not None otherwise as :math:`H = A^H A`. - """ - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Create H = A^H W A - operator = operator.H @ dcf_operator @ operator - else: - # Create H = A^H A - operator = operator.H @ operator - - return operator - - def _right_hand_side(self, kdata: KData) -> torch.Tensor: - """Calculate the right-hand-side of the normal equation. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Calculate the right-hand-side as :math:`b = A^H W y` if the DCF is not None otherwise as :math:`b = A^H y`. - - Parameters - ---------- - kdata - k-space data to reconstruct. - """ - device = kdata.data.device - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Calculate b = A^H W y - (right_hand_side,) = operator.to(device).H(dcf_operator(kdata.data)[0]) - else: - # Calculate b = A^H y - (right_hand_side,) = operator.to(device).H(kdata.data) - - return right_hand_side - - def forward(self, kdata: KData) -> IData: - """Apply the reconstruction. - - Parameters - ---------- - kdata - k-space data to reconstruct. - - Returns - ------- - the reconstruced image. - """ - device = kdata.data.device - if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) - - operator = self._self_adjoint_operator().to(device) - right_hand_side = self._right_hand_side(kdata) - - img_tensor = cg( - operator, right_hand_side, initial_value=right_hand_side, max_iterations=self.n_iterations, tolerance=0.0 - ) - img = IData.from_tensor_and_kheader(img_tensor, kdata.header) - return img + super().__init__(kdata, fourier_op, csm, noise, dcf, n_iterations=n_iterations, regularization_weight=0) diff --git a/src/mrpro/algorithms/reconstruction/Reconstruction.py b/src/mrpro/algorithms/reconstruction/Reconstruction.py index 127be8d5f..c4208157e 100644 --- a/src/mrpro/algorithms/reconstruction/Reconstruction.py +++ b/src/mrpro/algorithms/reconstruction/Reconstruction.py @@ -101,15 +101,13 @@ def direct_reconstruction(self, kdata: KData) -> IData: ------- image data """ - device = kdata.data.device if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) + kdata = prewhiten_kspace(kdata, self.noise) operator = self.fourier_op if self.csm is not None: operator = operator @ self.csm.as_operator() if self.dcf is not None: operator = self.dcf.as_operator() @ operator - operator = operator.to(device) (img_tensor,) = operator.H(kdata.data) img = IData.from_tensor_and_kheader(img_tensor, kdata.header) return img diff --git a/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py new file mode 100644 index 000000000..c9a307ebe --- /dev/null +++ b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py @@ -0,0 +1,139 @@ +"""Regularized Iterative SENSE Reconstruction by adjoint Fourier transform.""" + +from __future__ import annotations + +from collections.abc import Callable + +import torch + +from mrpro.algorithms.optimizers.cg import cg +from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace +from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.data._kdata.KData import KData +from mrpro.data.CsmData import CsmData +from mrpro.data.DcfData import DcfData +from mrpro.data.IData import IData +from mrpro.data.KNoise import KNoise +from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperator import LinearOperator + + +class RegularizedIterativeSENSEReconstruction(DirectReconstruction): + r"""Regularized iterative SENSE reconstruction. + + This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2 + + \frac{1}{2}L||Bx - x_0||_2^2` + by using a conjugate gradient algorithm to solve + :math:`H x = b` with :math:`H = A^H W A + L B` and :math:`b = A^H W y + L x_0` where :math:`A` + is the acquisition model (coil sensitivity maps, Fourier operator, k-space sampling), :math:`y` is the acquired + k-space data, :math:`W` describes the density compensation, :math:`L` is the strength of the regularization and + :math:`x_0` is the regularization image (i.e. the prior). :math:`B` is a linear operator applied to :math:`x`. + """ + + n_iterations: int + """Number of CG iterations.""" + + regularization_data: torch.Tensor + """Regularization data (i.e. prior) :math:`x_0`.""" + + regularization_weight: torch.Tensor + """Strength of the regularization :math:`L`.""" + + regularization_op: LinearOperator + """Linear operator :math:`B` applied to the current estimate in the regularization term.""" + + def __init__( + self, + kdata: KData | None = None, + fourier_op: LinearOperator | None = None, + csm: Callable | CsmData | None = CsmData.from_idata_walsh, + noise: KNoise | None = None, + dcf: DcfData | None = None, + *, + n_iterations: int = 5, + regularization_data: float | torch.Tensor = 0.0, + regularization_weight: float | torch.Tensor, + regularization_op: LinearOperator | None = None, + ) -> None: + """Initialize RegularizedIterativeSENSEReconstruction. + + For a unregularized version of the iterative SENSE algorithm the regularization_weight can be set to 0 or + IterativeSENSEReconstruction algorithm can be used. + + Parameters + ---------- + kdata + KData. If kdata is provided and fourier_op or dcf are None, then fourier_op and dcf are estimated based on + kdata. Otherwise fourier_op and dcf are used as provided. + fourier_op + Instance of the FourierOperator used for reconstruction. If None, set up based on kdata. + csm + Sensitivity maps for coil combination. If None, no coil combination is carried out, i.e. images for each + coil are returned. If a callable is provided, coil images are reconstructed using the adjoint of the + FourierOperator (including density compensation) and then sensitivity maps are calculated using the + callable. For this, kdata needs also to be provided. For examples have a look at the CsmData class + e.g. from_idata_walsh or from_idata_inati. + noise + KNoise used for prewhitening. If None, no prewhitening is performed + dcf + K-space sampling density compensation. If None, set up based on kdata. + n_iterations + Number of CG iterations + regularization_data + Regularization data, e.g. a reference image (:math:`x_0`). + regularization_weight + Strength of the regularization (:math:`L`). + regularization_op + Linear operator :math:`B` applied to the current estimate in the regularization term. If None, nothing is + applied to the current estimate. + + + Raises + ------ + ValueError + If the kdata and fourier_op are None or if csm is a Callable but kdata is None. + """ + super().__init__(kdata, fourier_op, csm, noise, dcf) + self.n_iterations = n_iterations + self.regularization_data = torch.as_tensor(regularization_data) + self.regularization_weight = torch.as_tensor(regularization_weight) + self.regularization_op = regularization_op if regularization_op is not None else IdentityOp() + + def forward(self, kdata: KData) -> IData: + """Apply the reconstruction. + + Parameters + ---------- + kdata + k-space data to reconstruct. + + Returns + ------- + the reconstruced image. + """ + if self.noise is not None: + kdata = prewhiten_kspace(kdata, self.noise) + + # Create the normal operator as H = A^H W A if the DCF is not None otherwise as H = A^H A. + # The acquisition model is A = F S if the CSM S is defined otherwise A = F with the Fourier operator F + csm_op = self.csm.as_operator() if self.csm is not None else IdentityOp() + precondition_op = self.dcf.as_operator() if self.dcf is not None else IdentityOp() + operator = (self.fourier_op @ csm_op).H @ precondition_op @ (self.fourier_op @ csm_op) + + # Calculate the right-hand-side as b = A^H W y if the DCF is not None otherwise as b = A^H y. + (right_hand_side,) = (self.fourier_op @ csm_op).H(precondition_op(kdata.data)[0]) + + # Add regularization + if not torch.all(self.regularization_weight == 0): + operator = operator + IdentityOp() @ (self.regularization_weight * self.regularization_op) + right_hand_side += self.regularization_weight * self.regularization_data + + img_tensor = cg( + operator, + right_hand_side, + initial_value=right_hand_side, + max_iterations=self.n_iterations, + tolerance=0.0, + ) + img = IData.from_tensor_and_kheader(img_tensor, kdata.header) + return img diff --git a/src/mrpro/algorithms/reconstruction/__init__.py b/src/mrpro/algorithms/reconstruction/__init__.py index 180186dd4..38b539f8b 100644 --- a/src/mrpro/algorithms/reconstruction/__init__.py +++ b/src/mrpro/algorithms/reconstruction/__init__.py @@ -1,8 +1,10 @@ from mrpro.algorithms.reconstruction.Reconstruction import Reconstruction from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import RegularizedIterativeSENSEReconstruction from mrpro.algorithms.reconstruction.IterativeSENSEReconstruction import IterativeSENSEReconstruction __all__ = [ "DirectReconstruction", "IterativeSENSEReconstruction", - "Reconstruction" -] \ No newline at end of file + "Reconstruction", + "RegularizedIterativeSENSEReconstruction" +] From 87cdaf1fcda5c537405c4fd73a363e0bde706a7f Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 29 Oct 2024 10:56:32 +0100 Subject: [PATCH 05/10] Release 0.241029 (#491) --- src/mrpro/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 0a67c464f..0f6ae6fb6 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241015 +0.241029 From 8b73303fb31779e10909d0ccccfa9fd64122fb36 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Wed, 30 Oct 2024 11:33:11 +0100 Subject: [PATCH 06/10] Fix unrolled CG autograd (#492) --- pyproject.toml | 1 + src/mrpro/algorithms/optimizers/cg.py | 11 +++-------- tests/algorithms/test_cg.py | 10 ++++++++++ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d2780589..31798a35c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ testpaths = ["tests"] filterwarnings = [ "error", "ignore:'write_like_original':DeprecationWarning:pydicom:", + "ignore:Anomaly Detection has been enabled:UserWarning", #torch.autograd ] addopts = "-n auto" markers = ["cuda : Tests only to be run when cuda device is available"] diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index 54cb1d782..b879796b0 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -81,11 +81,6 @@ def cg( if torch.vdot(residual.flatten(), residual.flatten()) == 0: return solution - # squared tolerance; - # (we will check ||residual||^2 < tolerance^2 instead of ||residual|| < tol - # to avoid the computation of the root for the norm) - tolerance_squared = tolerance**2 - # dummy value. new value will be set in loop before first usage residual_norm_squared_previous = None @@ -95,7 +90,7 @@ def cg( residual_norm_squared = torch.vdot(residual_flat, residual_flat).real # check if the solution is already accurate enough - if tolerance != 0 and (residual_norm_squared < tolerance_squared): + if tolerance != 0 and (residual_norm_squared < tolerance**2): return solution if iteration > 0: @@ -105,8 +100,8 @@ def cg( # update estimates of the solution and the residual (operator_conjugate_vector,) = operator(conjugate_vector) alpha = residual_norm_squared / (torch.vdot(conjugate_vector.flatten(), operator_conjugate_vector.flatten())) - solution += alpha * conjugate_vector - residual -= alpha * operator_conjugate_vector + solution = solution + alpha * conjugate_vector + residual = residual - alpha * operator_conjugate_vector residual_norm_squared_previous = residual_norm_squared diff --git a/tests/algorithms/test_cg.py b/tests/algorithms/test_cg.py index 16c16e78b..8a4434e2a 100644 --- a/tests/algorithms/test_cg.py +++ b/tests/algorithms/test_cg.py @@ -145,3 +145,13 @@ def callback(cg_status: CGStatus) -> None: assert True cg(h_operator, right_hand_side, callback=callback) + + +def test_autograd(system): + """Test autograd through cg""" + h_operator, right_hand_side, _ = system + right_hand_side.requires_grad_(True) + with torch.autograd.detect_anomaly(): + result = cg(h_operator, right_hand_side, tolerance=0, max_iterations=5) + result.abs().sum().backward() + assert right_hand_side.grad is not None From eeb923354ae00a1ee728fc20be8d8495afa0f468 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 2 Nov 2024 19:29:46 +0100 Subject: [PATCH 07/10] Add matrix of LinearOperators (#432) --- src/mrpro/operators/LinearOperator.py | 16 +- src/mrpro/operators/LinearOperatorMatrix.py | 363 +++++++++++++++++++ src/mrpro/operators/__init__.py | 2 + tests/operators/test_linearoperatormatrix.py | 322 ++++++++++++++++ tests/operators/test_operator_norm.py | 5 +- 5 files changed, 704 insertions(+), 4 deletions(-) create mode 100644 src/mrpro/operators/LinearOperatorMatrix.py create mode 100644 tests/operators/test_linearoperatormatrix.py diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index f136d9d51..d919e63c6 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -102,7 +102,7 @@ def operator_norm( max_iterations: int = 20, relative_tolerance: float = 1e-4, absolute_tolerance: float = 1e-5, - callback: Callable | None = None, + callback: Callable[[torch.Tensor], None] | None = None, ) -> torch.Tensor: """Power iteration for computing the operator norm of the linear operator. @@ -294,6 +294,20 @@ def __rmul__(self, other: torch.Tensor | complex) -> LinearOperator: else: return NotImplemented # type: ignore[unreachable] + def __and__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Vertical stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self], [other]] + return mrpro.operators.LinearOperatorMatrix(operators) + + def __or__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Horizontal stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self, other]] + return mrpro.operators.LinearOperatorMatrix(operators) + @property def gram(self) -> LinearOperator: """Gram operator. diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py new file mode 100644 index 000000000..ab0673398 --- /dev/null +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -0,0 +1,363 @@ +"""Linear Operator Matrix class.""" + +from __future__ import annotations + +import operator +from collections.abc import Callable, Iterator, Sequence +from functools import reduce +from types import EllipsisType +from typing import cast + +import torch +from typing_extensions import Self, Unpack + +from mrpro.operators.LinearOperator import LinearOperator, LinearOperatorSum +from mrpro.operators.Operator import Operator +from mrpro.operators.ZeroOp import ZeroOp + +_SingleIdxType = int | slice | EllipsisType | Sequence[int] +_IdxType = _SingleIdxType | tuple[_SingleIdxType, _SingleIdxType] + + +class LinearOperatorMatrix(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]]): + r"""Matrix of Linear Operators. + + A matrix of Linear Operators, where each element is a Linear Operator. + + This matrix can be applied to a sequence of tensors, where the number of tensors should match + the number of columns of the matrix. The output will be a sequence of tensors, where the number + of tensors will match the number of rows of the matrix. + The i-th output tensor is calculated as + :math:`\sum_j \text{operators}[i][j](x[j])` where :math:`\text{operators}[i][j]` is the linear operator + in the i-th row and j-th column and :math:`x[j]` is the j-th input tensor. + + The matrix can be indexed and sliced like a regular matrix to get submatrices. + If indexing returns a single element, it is returned as a Linear Operator. + + Basic arithmetic operations are supported with Linear Operators and Tensors. + + """ + + _operators: list[list[LinearOperator]] + + def __init__(self, operators: Sequence[Sequence[LinearOperator]]): + """Initialize Linear Operator Matrix. + + Create a matrix of LinearOperators from a sequence of rows, where each row is a sequence + of LinearOperators that represent the columns of the matrix. + + Parameters + ---------- + operators + A sequence of rows, which are sequences of Linear Operators. + """ + if not all(isinstance(op, LinearOperator) for row in operators for op in row): + raise ValueError('All elements should be LinearOperators.') + if not all(len(row) == len(operators[0]) for row in operators): + raise ValueError('All rows should have the same length.') + super().__init__() + self._operators = cast( # cast because ModuleList is not recognized as a list + list[list[LinearOperator]], torch.nn.ModuleList(torch.nn.ModuleList(row) for row in operators) + ) + self._shape = (len(operators), len(operators[0]) if operators else 0) + + @property + def shape(self) -> tuple[int, int]: + """Shape of the Operator Matrix (rows, columns).""" + return self._shape + + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has columns. + + Returns + ------- + Output tensors. The same number of tensors as the operator has rows. + """ + if len(x) != self.shape[1]: + raise ValueError('Input should be the same number of tensors as the LinearOperatorMatrix has columns.') + return tuple( + reduce(operator.add, (op(xi)[0] for op, xi in zip(row, x, strict=True))) for row in self._operators + ) + + def __getitem__(self, idx: _IdxType) -> Self | LinearOperator: + """Index the Operator Matrix. + + Parameters + ---------- + idx + Index or slice to select rows and columns. + + Returns + ------- + Subset LinearOperatorMatrix or Linear Operator. + """ + idxs: tuple[_SingleIdxType, _SingleIdxType] = idx if isinstance(idx, tuple) else (idx, slice(None)) + if len(idxs) > 2: + raise IndexError('Too many indices for LinearOperatorMatrix') + + def _to_numeric_index(idx: slice | int | Sequence[int] | EllipsisType, length: int) -> Sequence[int]: + """Convert index to a sequence of integers or raise an error.""" + if isinstance(idx, slice): + if (idx.start is not None and (idx.start < -length or idx.start >= length)) or ( + idx.stop is not None and (idx.stop < -length or idx.stop > length) + ): + raise IndexError('Index out of range') + return range(*idx.indices(length)) + if isinstance(idx, int): + if idx < -length or idx >= length: + raise IndexError('Index out of range') + return (idx,) + if idx is Ellipsis: + return range(length) + if isinstance(idx, Sequence): + if min(idx) < -length or max(idx) >= length: + raise IndexError('Index out of range') + return idx + else: + raise IndexError('Invalid index type') + + row_numbers = _to_numeric_index(idxs[0], self._shape[0]) + col_numbers = _to_numeric_index(idxs[1], self._shape[1]) + + sliced_operators = [ + [row[col_number] for col_number in col_numbers] + for row in [self._operators[row_number] for row_number in row_numbers] + ] + + # Return a single operator if only one row and column is selected + if len(row_numbers) == 1 and len(col_numbers) == 1: + return sliced_operators[0][0] + else: + return self.__class__(sliced_operators) + + def __iter__(self) -> Iterator[Sequence[LinearOperator]]: + """Iterate over the rows of the Operator Matrix.""" + return iter(self._operators) + + def __repr__(self): + """Representation of the Operator Matrix.""" + return f'LinearOperatorMatrix(shape={self._shape}, operators={self._operators})' + + # Note: The type ignores are needed because we currently cannot do arithmetic operations with non-linear operators. + def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: ignore[override] + """Addition.""" + operators: list[list[LinearOperator]] = [] + if isinstance(other, LinearOperatorMatrix): + if self.shape != other.shape: + raise ValueError('OperatorMatrix shapes do not match.') + for self_row, other_row in zip(self._operators, other._operators, strict=False): + operators.append([s + o for s, o in zip(self_row, other_row, strict=False)]) + elif isinstance(other, LinearOperator | torch.Tensor): + if not self.shape[0] == self.shape[1]: + raise NotImplementedError('Cannot add a LinearOperator to a non-square OperatorMatrix.') + for i, self_row in enumerate(self._operators): + operators.append([op + other if i == j else op for j, op in enumerate(self_row)]) + else: + return NotImplemented # type: ignore[unreachable] + return self.__class__(operators) + + def __radd__(self, other: Self | LinearOperator | torch.Tensor) -> Self: + """Right addition.""" + return self.__add__(other) + + def __mul__(self, other: torch.Tensor | Sequence[torch.Tensor | complex] | complex) -> Self: + """LinearOperatorMatrix*Tensor multiplication. + + Example: ([A,B]*c)(x) = [A*c, B*c](x) = A(c*x) + B(c*x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[1] + elif len(other) != self.shape[1]: + raise ValueError('Other should have the same length as the operator has columns.') + else: + other_ = other + operators = [] + for row in self._operators: + operators.append([op * o for op, o in zip(row, other_, strict=True)]) + return self.__class__(operators) + + def __rmul__(self, other: torch.Tensor | Sequence[torch.Tensor] | complex) -> Self: + """Tensor*LinearOperatorMatrix multiplication. + + Example: (c*[A,B])(x) = [c*A, c*B](x) = c*A(x) + c*B(x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[0] + elif len(other) != self.shape[0]: + raise ValueError('Other should have the same length as the operator has rows.') + else: + other_ = other + operators = [] + for row, o in zip(self._operators, other_, strict=True): + operators.append([cast(LinearOperator, o * op) for op in row]) + return self.__class__(operators) + + def __matmul__(self, other: LinearOperator | Self) -> Self: # type: ignore[override] + """Composition of operators.""" + if isinstance(other, LinearOperator): + return self._binary_operation(other, '__matmul__') + elif isinstance(other, LinearOperatorMatrix): + if self.shape[1] != other.shape[0]: + raise ValueError('OperatorMatrix shapes do not match.') + new_operators = [] + for row in self._operators: + new_row = [] + for other_col in zip(*other._operators, strict=True): + elements = [s @ o for s, o in zip(row, other_col, strict=True)] + new_row.append(LinearOperatorSum(*elements)) + new_operators.append(new_row) + return self.__class__(new_operators) + return NotImplemented # type: ignore[unreachable] + + @property + def H(self) -> Self: # noqa N802 + """Adjoints of the operators.""" + return self.__class__([[op.H for op in row] for row in zip(*self._operators, strict=True)]) + + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the adjoint of the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has rows. + + Returns + ------- + Output tensors. The same number of tensors as the operator has columns. + """ + return self.H(*x) + + @classmethod + def from_diagonal(cls, *operators: LinearOperator): + """Create a diagonal LinearOperatorMatrix. + + Create a square LinearOperatorMatrix with the given Linear Operators on the diagonal, + resulting in a block-diagonal linear operator. + + Parameters + ---------- + operators + Sequence of Linear Operators to be placed on the diagonal. + """ + operator_matrix: list[list[LinearOperator]] = [ + [op if i == j else ZeroOp(False) for j in range(len(operators))] for i, op in enumerate(operators) + ] + return cls(operator_matrix) + + def operator_norm( + self, + *initial_value: torch.Tensor, + dim: Sequence[int] | None = None, + max_iterations: int = 20, + relative_tolerance: float = 1e-4, + absolute_tolerance: float = 1e-5, + callback: Callable[[torch.Tensor], None] | None = None, + ) -> torch.Tensor: + """Upper bound of operator norm of the Matrix. + + Uses the bounds :math:`||[A, B]^T|||<=sqrt(||A||^2 + ||B||^2)` and :math:`||[A, B]|||<=max(||A||,||B||)` + to estimate the operator norm of the matrix. + First, operator_norm is called on each element of the matrix. + Next, the norm is estimated for each column using the first bound. + Finally, the norm of the full matrix of linear operators is calculated using the second bound. + + Parameters + ---------- + initial_value + Initial value(s) for the power iteration, length should match the number of columns + of the operator matrix. + dim + Dimensions to calculate the operator norm over. Other dimensions are assumed to be + batch dimensions. None means all dimensions. + max_iterations + Maximum number of iterations used in the power iteration. + relative_tolerance + Relative tolerance for convergence. + absolute_tolerance + Absolute tolerance for convergence. + callback + Callback function to be called with the current estimate of the operator norm. + + + Returns + ------- + Estimated operator norm upper bound. + """ + + def _singlenorm(op: LinearOperator, initial_value: torch.Tensor): + return op.operator_norm( + initial_value, + dim=dim, + max_iterations=max_iterations, + relative_tolerance=relative_tolerance, + absolute_tolerance=absolute_tolerance, + callback=callback, + ) + + if len(initial_value) != self.shape[1]: + raise ValueError('Initial value should have the same length as the operator has columns.') + norms = torch.tensor( + [[_singlenorm(op, iv) for op, iv in zip(row, initial_value, strict=True)] for row in self._operators] + ) + norm = norms.square().sum(-2).sqrt().amax(-1).unsqueeze(-1) + return norm + + def __or__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Vertical stacking.""" + if isinstance(other, LinearOperator): + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking : cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[*self._operators[0], other]] + return self.__class__(operators) + else: + if (rows_self := self.shape[0]) != (rows_other := other.shape[0]): + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack matrices with {rows_self} and {rows_other}.' + ) + operators = [[*self_row, *other_row] for self_row, other_row in zip(self, other, strict=True)] + return self.__class__(operators) + + def __ror__(self, other: LinearOperator) -> Self: + """Vertical stacking.""" + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[other, *self._operators[0]]] + return self.__class__(operators) + + def __and__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Horizontal stacking.""" + if isinstance(other, LinearOperator): + if cols := self.shape[1] > 1: + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack LinearOperator and matrix with {cols} columns.' + ) + operators = [*self._operators, [other]] + return self.__class__(operators) + else: + if (cols_self := self.shape[1]) != (cols_other := other.shape[1]): + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack matrices with {cols_self} and {cols_other} columns.' + ) + operators = [*self._operators, *other] + return self.__class__(operators) + + def __rand__(self, other: LinearOperator) -> Self: + """Horizontal stacking.""" + if cols := self.shape[1] > 1: + raise ValueError( + f'Shape mismatch in horizontal stacking: cannot stack LinearOperator and matrix with {cols} columns.' + ) + operators = [[other], *self._operators] + return self.__class__(operators) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 345c337fa..b8c16ebfe 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -11,6 +11,7 @@ from mrpro.operators.FourierOp import FourierOp from mrpro.operators.GridSamplingOp import GridSamplingOp from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp from mrpro.operators.PhaseOp import PhaseOp @@ -36,6 +37,7 @@ "GridSamplingOp", "IdentityOp", "LinearOperator", + "LinearOperatorMatrix", "MagnitudeOp", "MultiIdentityOp", "Operator", diff --git a/tests/operators/test_linearoperatormatrix.py b/tests/operators/test_linearoperatormatrix.py new file mode 100644 index 000000000..7ba87d715 --- /dev/null +++ b/tests/operators/test_linearoperatormatrix.py @@ -0,0 +1,322 @@ +from typing import Any + +import pytest +import torch +from mrpro.operators import EinsumOp, LinearOperator, MagnitudeOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +def random_linearop(size, rng): + """Create a random LinearOperator.""" + return EinsumOp(rng.complex64_tensor(size), '... i j, ... j -> ... i') + + +def random_linearoperatormatrix(size, inner_size, rng): + """Create a random LinearOperatorMatrix.""" + operators = [[random_linearop(inner_size, rng) for i in range(size[1])] for j in range(size[0])] + return LinearOperatorMatrix(operators) + + +def test_linearoperatormatrix_shape(): + """Test creation and shape of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert matrix.shape == (5, 3) + + +def test_linearoperatormatrix_add_matrix(): + """Test addition of two LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 + matrix2)(*vector) + expected = tuple(a + b for a, b in zip(matrix1(*vector), matrix2(*vector), strict=False)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_add_tensor_nonsquare(): + """Test failure of addition of tensor to non-square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + with pytest.raises(NotImplementedError, match='square'): + (matrix1 + other) + + +def test_linearoperatormatrix_add_tensor_square(): + """Add tensor to square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 3), (2, 2), rng) + other = rng.complex64_tensor(2) + vector = rng.complex64_tensor((3, 2)) + result = (matrix1 + other)(*vector) + expected = tuple((mv + other * v for mv, v in zip(matrix1(*vector), vector, strict=True))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_rmul(): + """Test post multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + vector = rng.complex64_tensor((3, 10)) + result = (other * matrix1)(*vector) + expected = tuple(other * el for el in matrix1(*vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_mul(): + """Test pre multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(10) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 * other)(*vector) + expected = matrix1(*(other * el for el in vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition(): + """Test composition of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 @ matrix2)(*vector) + expected = matrix1(*(matrix2(*vector))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition_mismatch(): + """Test composition with mismatching shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((4, 3), (3, 10), rng) + vector = rng.complex64_tensor((4, 10)) + with pytest.raises(ValueError, match='shapes do not match'): + (matrix1 @ matrix2)(*vector) + + +def test_linearoperatormatrix_adjoint(): + """Test adjointness of Adjoint.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + dotproduct_adjointness_test(Wrapper(), rng.complex64_tensor((3, 10)), rng.complex64_tensor((5, 3))) + + +def test_linearoperatormatrix_repr(): + """Test repr of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert 'LinearOperatorMatrix(shape=(5, 3)' in repr(matrix) + + +def test_linearoperatormatrix_getitem(): + """Test slicing of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + def check(actual, expected): + assert tuple(tuple(row) for row in actual) == tuple(tuple(row) for row in expected) + + sliced = matrix[1:3, 2] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[2:3] for row in matrix._operators[1:3]]) + + sliced = matrix[0] + assert sliced.shape == (1, 6) + check(sliced._operators, matrix._operators[:1]) + + sliced = matrix[..., 0] + assert sliced.shape == (12, 1) + check(sliced._operators, [row[:1] for row in matrix._operators]) + + sliced = matrix[1:6:2, (3, 4)] + assert sliced.shape == (3, 2) + check(sliced._operators, [[matrix._operators[i][j] for j in (3, 4)] for i in range(1, 6, 2)]) + + sliced = matrix[-2:-4:-1, -1] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[-1:] for row in matrix._operators[-2:-4:-1]]) + + sliced = matrix[5, 5] + assert isinstance(sliced, LinearOperator) + assert sliced == matrix._operators[5][5] + + +def test_linearoperatormatrix_getitem_error(): + """Test error when slicing with wrong indices.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + with pytest.raises(IndexError, match='Too many indices'): + matrix[1, 1, 1] + + with pytest.raises(IndexError, match='out of range'): + matrix[20] + with pytest.raises(IndexError, match='out of range'): + matrix[-20] + with pytest.raises(IndexError, match='out of range'): + matrix[1:100] + with pytest.raises(IndexError, match='out of range'): + matrix[(100, 1)] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., -20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 1:100] + with pytest.raises(IndexError, match='index type'): + matrix[..., 1.0] + + +def test_linearoperatormatrix_norm_rows(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((3, 1), (3, 10), rng) + vector = rng.complex64_tensor((1, 10)) + result = matrix.operator_norm(*vector) + expected = sum(row[0].operator_norm(vector[0], dim=None) ** 2 for row in matrix._operators) ** 0.5 + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_norm_cols(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((1, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = matrix.operator_norm(*vector) + expected = max(op.operator_norm(v, dim=None) for op, v in zip(matrix._operators[0], vector, strict=False)) + torch.testing.assert_close(result, expected) + + +@pytest.mark.parametrize('seed', [0, 1, 2, 3]) +def test_linearoperatormatrix_norm(seed): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(seed) + matrix = random_linearoperatormatrix((4, 2), (3, 10), rng) + vector = rng.complex64_tensor((2, 10)) + result = matrix.operator_norm(*vector) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + real = Wrapper().operator_norm(vector, dim=None) + + assert result >= real + + +def test_linearoperatormatrix_shorthand_vertical(): + """Test shorthand for vertical stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 10), rng) + op2 = random_linearop((4, 10), rng) + x1 = rng.complex64_tensor((10,)) + + matrix1 = op1 & op2 + assert matrix1.shape == (2, 1) + + actual = matrix1(x1) + expected = (*op1(x1), *op2(x1)) + torch.testing.assert_close(actual, expected) + + matrix2 = op2 & (matrix1 & op1) + assert matrix2.shape == (4, 1) + + matrix3 = matrix2 & matrix2 + assert matrix3.shape == (8, 1) + + actual = matrix3(x1) + expected = 2 * (*op2(x1), *matrix1(x1), *op1(x1)) + torch.testing.assert_close(actual, expected) + + +def test_linearoperatormatrix_shorthand_horizontal(): + """Test shorthand for horizontal stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 4), rng) + op2 = random_linearop((3, 2), rng) + x1 = rng.complex64_tensor((4,)) + x2 = rng.complex64_tensor((2,)) + x3 = rng.complex64_tensor((4,)) + x4 = rng.complex64_tensor((2,)) + + matrix1 = op1 | op2 + assert matrix1.shape == (1, 2) + + actual1 = matrix1(x1, x2) + expected1 = (op1(x1)[0] + op2(x2)[0],) + torch.testing.assert_close(actual1, expected1) + + matrix2 = op2 | (matrix1 | op1) + assert matrix2.shape == (1, 4) + + matrix3 = matrix2 | matrix2 + assert matrix3.shape == (1, 8) + + expected3 = (2 * (op2(x2)[0] + (matrix1(x3, x4)[0] + op1(x1)[0])),) + actual3 = matrix3(x2, x3, x4, x1, x2, x3, x4, x1) + torch.testing.assert_close(actual3, expected3) + + +def test_linearoperatormatrix_stacking_error(): + """Test error when stacking matrix operators with different shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 4), (3, 10), rng) + matrix2 = random_linearoperatormatrix((3, 2), (3, 10), rng) + matrix3 = random_linearoperatormatrix((2, 4), (3, 10), rng) + op = random_linearop((3, 10), rng) + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & matrix2 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | matrix3 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | op + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & op + + +def test_linearoperatormatrix_error_nonlinearop(): + """Test error if trying to create a LinearOperatorMatrix with non linear operator.""" + op: Any = [[MagnitudeOp()]] # Any is used to hide this error from mypy + with pytest.raises(ValueError, match='LinearOperator'): + LinearOperatorMatrix(op) + + +def test_linearoperatormatrix_error_inconsistent_shapes(): + """Test error if trying to create a LinearOperatorMatrix with inonsistent row lengths.""" + rng = RandomGenerator(0) + op = random_linearop((3, 4), rng) + with pytest.raises(ValueError, match='same length'): + LinearOperatorMatrix([[op, op], [op]]) + + +def test_linearoperatormatrix_from_diagonal(): + """Test creation of LinearOperatorMatrix from diagonal.""" + rng = RandomGenerator(0) + ops = [random_linearop((2, 4), rng) for _ in range(3)] + matrix = LinearOperatorMatrix.from_diagonal(*ops) + xs = rng.complex64_tensor((3, 4)) + actual = matrix(*xs) + expected = tuple(op(x)[0] for op, x in zip(ops, xs, strict=False)) + torch.testing.assert_close(actual, expected) diff --git a/tests/operators/test_operator_norm.py b/tests/operators/test_operator_norm.py index d5ae39873..c6e13dcf9 100644 --- a/tests/operators/test_operator_norm.py +++ b/tests/operators/test_operator_norm.py @@ -12,9 +12,8 @@ def test_power_iteration_uses_stopping_criterion(): """Test if the power iteration stops if the absolute and relative tolerance are chosen high.""" - # callback function that should not be called because the power iteration - # should stop if the tolerances are set high - def callback(): + def callback(_): + """Callback function that should not be called, because the power iteration should stop.""" pytest.fail('The power iteration did not stop despite high atol and rtol!') random_generator = RandomGenerator(seed=0) From bc830f5c345bfcc123a627e31395c77060854787 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:14:13 +0100 Subject: [PATCH 08/10] [pre-commit] pre-commit autoupdate (#498) Co-authored-by: Patrick Schuenke --- .pre-commit-config.yaml | 8 +++---- examples/direct_reconstruction.ipynb | 8 +++---- examples/direct_reconstruction.py | 8 +++---- examples/iterative_sense_reconstruction.ipynb | 8 +++---- examples/iterative_sense_reconstruction.py | 8 +++---- examples/pulseq_2d_radial_golden_angle.ipynb | 22 ++++++++++--------- examples/pulseq_2d_radial_golden_angle.py | 17 +++++++------- ...rized_iterative_sense_reconstruction.ipynb | 8 +++---- ...ularized_iterative_sense_reconstruction.py | 8 +++---- 9 files changed, 49 insertions(+), 46 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7803441eb..95d14317a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-docstring-first @@ -15,14 +15,14 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff # linter args: [--fix] - id: ruff-format # formatter - repo: https://github.com/crate-ci/typos - rev: v1.25.0 + rev: v1.27.0 hooks: - id: typos @@ -34,7 +34,7 @@ repos: exclude: ^tests/ - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy pass_filenames: false diff --git a/examples/direct_reconstruction.ipynb b/examples/direct_reconstruction.ipynb index 3b6dc930e..1e4e74c9c 100644 --- a/examples/direct_reconstruction.ipynb +++ b/examples/direct_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { diff --git a/examples/direct_reconstruction.py b/examples/direct_reconstruction.py index 7672aa7e7..5d55812c9 100644 --- a/examples/direct_reconstruction.py +++ b/examples/direct_reconstruction.py @@ -11,10 +11,10 @@ import requests -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/iterative_sense_reconstruction.ipynb b/examples/iterative_sense_reconstruction.ipynb index 87249b2fb..f612d7522 100644 --- a/examples/iterative_sense_reconstruction.ipynb +++ b/examples/iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { diff --git a/examples/iterative_sense_reconstruction.py b/examples/iterative_sense_reconstruction.py index 6d0bc49a5..ba5e6a01a 100644 --- a/examples/iterative_sense_reconstruction.py +++ b/examples/iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index 52e0310bb..bcb4482a1 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -33,13 +33,14 @@ "cell_type": "code", "execution_count": null, "id": "d16f41f1", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "# define zenodo records URL and create a temporary directory and h5-file\n", "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", - "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')" + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" ] }, { @@ -50,9 +51,10 @@ "outputs": [], "source": [ "# Download raw data using requests\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { @@ -125,10 +127,10 @@ "# download the sequence file from zenodo\n", "zenodo_url = 'https://zenodo.org/records/10868061/files/'\n", "seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq'\n", - "seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq')\n", - "response = requests.get(zenodo_url + seq_fname, timeout=30)\n", - "seq_file.write(response.content)\n", - "seq_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file:\n", + " response = requests.get(zenodo_url + seq_fname, timeout=30)\n", + " seq_file.write(response.content)\n", + " seq_file.flush()" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index 3f857c382..f4db5217a 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -19,13 +19,14 @@ # define zenodo records URL and create a temporary directory and h5-file zenodo_url = 'https://zenodo.org/records/10854057/files/' fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') + # %% # Download raw data using requests -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction using KTrajectoryIsmrmrd @@ -62,10 +63,10 @@ # download the sequence file from zenodo zenodo_url = 'https://zenodo.org/records/10868061/files/' seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq' -seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') -response = requests.get(zenodo_url + seq_fname, timeout=30) -seq_file.write(response.content) -seq_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file: + response = requests.get(zenodo_url + seq_fname, timeout=30) + seq_file.write(response.content) + seq_file.flush() # %% # Read raw data and calculate trajectory using KTrajectoryPulseq diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb index 6b1c2704b..0a6743161 100644 --- a/examples/regularized_iterative_sense_reconstruction.ipynb +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py index e41dc4ac5..2ab7ba033 100644 --- a/examples/regularized_iterative_sense_reconstruction.py +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction From dcb5e3407a32e2cbd9085ffbe4ce9430f8bacad0 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 9 Nov 2024 15:26:16 +0100 Subject: [PATCH 09/10] Add PCA-based compression operator (#181) Co-authored-by: Christoph Kolbitsch --- src/mrpro/operators/PCACompressionOp.py | 85 ++++++++++++++++++++++ src/mrpro/operators/__init__.py | 2 + tests/operators/test_pca_compression_op.py | 50 +++++++++++++ 3 files changed, 137 insertions(+) create mode 100644 src/mrpro/operators/PCACompressionOp.py create mode 100644 tests/operators/test_pca_compression_op.py diff --git a/src/mrpro/operators/PCACompressionOp.py b/src/mrpro/operators/PCACompressionOp.py new file mode 100644 index 000000000..ace625ea8 --- /dev/null +++ b/src/mrpro/operators/PCACompressionOp.py @@ -0,0 +1,85 @@ +"""PCA Compression Operator.""" + +import einops +import torch +from einops import repeat + +from mrpro.operators.LinearOperator import LinearOperator + + +class PCACompressionOp(LinearOperator): + """PCA based compression operator.""" + + def __init__( + self, + data: torch.Tensor, + n_components: int, + ): + """Construct a PCA based compression operator. + + The operator carries out an SVD followed by a threshold of the n_components largest values along the last + dimension of a data with shape (*other, joint_dim, compression_dim). A single SVD is carried out for everything + along joint_dim. Other are batch dimensions. + + Consider combining this operator with :class:`mrpro.operators.RearrangeOp` to make sure the data is + in the correct shape before applying. + + Parameters + ---------- + data + Data of shape (*other, joint_dim, compression_dim) to be used to find the principal components. + n_components + Number of principal components to keep along the compression_dim. + """ + super().__init__() + # different compression matrices along the *other dimensions + data = data - data.mean(-1, keepdim=True) + correlation = einops.einsum(data, data.conj(), '... joint comp1, ... joint comp2 -> ... comp1 comp2') + _, _, v = torch.svd(correlation) + # add joint_dim along which the the compression is the same + v = repeat(v, '... comp1 comp2 -> ... joint_dim comp1 comp2', joint_dim=1) + self.register_buffer('_compression_matrix', v[..., :n_components, :].clone()) + + def forward(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the compression to the data. + + Parameters + ---------- + data + data to be compressed of shape (*other, joint_dim, compression_dim) + + Returns + ------- + compressed data of shape (*other, joint_dim, n_components) + """ + try: + result = (self._compression_matrix @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix {tuple(self._compression_matrix.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) + + def adjoint(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint compression to the data. + + Parameters + ---------- + data + compressed data of shape (*other, joint_dim, n_components) + + Returns + ------- + expanded data of shape (*other, joint_dim, compression_dim) + """ + try: + result = (self._compression_matrix.mH @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix^H {tuple(self._compression_matrix.mH.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index b8c16ebfe..c22f386cd 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -14,6 +14,7 @@ from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp +from mrpro.operators.PCACompressionOp import PCACompressionOp from mrpro.operators.PhaseOp import PhaseOp from mrpro.operators.ProximableFunctionalSeparableSum import ProximableFunctionalSeparableSum from mrpro.operators.SensitivityOp import SensitivityOp @@ -41,6 +42,7 @@ "MagnitudeOp", "MultiIdentityOp", "Operator", + "PCACompressionOp", "PhaseOp", "ProximableFunctional", "ProximableFunctionalSeparableSum", diff --git a/tests/operators/test_pca_compression_op.py b/tests/operators/test_pca_compression_op.py new file mode 100644 index 000000000..a600bcecb --- /dev/null +++ b/tests/operators/test_pca_compression_op.py @@ -0,0 +1,50 @@ +"""Tests for PCA Compression Operator.""" + +import pytest +from mrpro.operators import PCACompressionOp + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +@pytest.mark.parametrize( + ('init_data_shape', 'input_shape', 'n_components'), + [ + ((40, 10), (100, 10), 6), + ((40, 10), (3, 4, 5, 100, 10), 3), + ((3, 4, 40, 10), (3, 4, 100, 10), 6), + ((3, 4, 40, 10), (7, 3, 4, 100, 10), 3), + ], +) +def test_pca_compression_op_adjoint(init_data_shape, input_shape, n_components): + """Test adjointness of PCA Compression Op.""" + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + u = generator.complex64_tensor(input_shape) + output_shape = (*input_shape[:-1], n_components) + v = generator.complex64_tensor(output_shape) + + # Create operator and apply + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=n_components) + dotproduct_adjointness_test(pca_comp_op, u, v) + + +def test_pca_compression_op_wrong_shapes(): + """Test if Operator raises error if shape mismatch.""" + init_data_shape = (10, 6) + input_shape = (100, 3) + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + input_data = generator.complex64_tensor(input_shape) + + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=2) + + with pytest.raises(RuntimeError, match='Matrix'): + pca_comp_op(input_data) + + with pytest.raises(RuntimeError, match='Matrix.H'): + pca_comp_op.adjoint(input_data) From c268ad25a429f93d3bb33df6cb2bf0ffc06c4759 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Sat, 9 Nov 2024 16:11:57 +0100 Subject: [PATCH 10/10] Fix CartesianSamplingOp (#483) --- src/mrpro/operators/CartesianSamplingOp.py | 10 ++++++---- tests/operators/test_cartesian_sampling_op.py | 17 +++++++++++++++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 64068a5d1..7a51924b1 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -47,26 +47,28 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2 else: sorted_grid_shape.x = ktraj_tensor.shape[-1] - kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1) + kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k2 k1 k0', other=1, k2=1, k1=1) if traj_type_kzyx[-2] == TrajType.ONGRID: # ky ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2 else: sorted_grid_shape.y = ktraj_tensor.shape[-2] - ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1) + ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k2 k1 k0', other=1, k2=1, k0=1) if traj_type_kzyx[-3] == TrajType.ONGRID: # kz kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2 else: sorted_grid_shape.z = ktraj_tensor.shape[-3] - kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1) + kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k2 k1 k0', other=1, k1=1, k0=1) # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') self.register_buffer('_fft_idx', kidx) # we can skip the indexing if the data is already sorted - self._needs_indexing = not torch.all(torch.diff(kidx) == 1) + self._needs_indexing = ( + not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + ) self._trajectory_shape = traj.broadcasted_shape self._sorted_grid_shape = sorted_grid_shape diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 6a1120e79..7caa13e7c 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -2,6 +2,7 @@ import pytest import torch +from einops import rearrange from mrpro.data import KTrajectory, SpatialDimension from mrpro.operators import CartesianSamplingOp @@ -59,6 +60,9 @@ def test_cart_sampling_op_data_match(): 'regular_undersampling', 'random_undersampling', 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', ], ) def test_cart_sampling_op_fwd_adj(sampling): @@ -70,8 +74,8 @@ def test_cart_sampling_op_fwd_adj(sampling): nky = (2, 1, 40, 1) nkz = (2, 20, 1, 1) sx = 'uf' - sy = 'uf' - sz = 'uf' + sy = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' + sz = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() # Subsample data and trajectory @@ -94,6 +98,15 @@ def test_cart_sampling_op_fwd_adj(sampling): for traj_one_other in trajectory_tensor.unbind(1) ] trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1)) + case 'cartesian_and_non_cartesian': + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0_undersampling': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + random_idx = torch.randperm(trajectory_tensor.shape[-1]) + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]]) case _: raise NotImplementedError(f'Test {sampling} not implemented.')