diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index 9f02e8dcc..a55b7860d 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -1,4 +1,4 @@ -name: Publish to PyPI +name: Build and Deploy Package on: push: @@ -7,21 +7,31 @@ on: jobs: build-testpypi-package: - name: Modify version and build package for TestPyPI + name: Build Package for TestPyPI runs-on: ubuntu-latest outputs: version: ${{ steps.set_suffix.outputs.version }} suffix: ${{ steps.set_suffix.outputs.suffix }} + version_changed: ${{ steps.changes.outputs.version_changed }} steps: - - name: Checkout code - uses: actions/checkout@v4 - - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.x" - - name: Set environment variable for version suffix + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Check if VERSION file is modified compared to main + uses: dorny/paths-filter@v3 + id: changes + with: + base: main + filters: | + version_changed: + - 'src/mrpro/VERSION' + + - name: Set Version Suffix id: set_suffix run: | VERSION=$(cat src/mrpro/VERSION) @@ -30,12 +40,12 @@ jobs: echo "suffix=$SUFFIX" >> $GITHUB_OUTPUT echo "version=$VERSION" >> $GITHUB_OUTPUT - - name: Build package with version suffix + - name: Build Package run: | python -m pip install --upgrade build python -m build - - name: Store TestPyPi distribution + - name: Upload TestPyPI Distribution Artifact uses: actions/upload-artifact@v4 with: name: testpypi-package-distribution @@ -46,6 +56,7 @@ jobs: needs: - build-testpypi-package runs-on: ubuntu-latest + if: needs.build-testpypi-package.outputs.version_changed == 'true' environment: name: testpypi @@ -55,7 +66,7 @@ jobs: id-token: write steps: - - name: Download TestPyPi distribution + - name: Download TestPyPI Distribution uses: actions/download-artifact@v4 with: name: testpypi-package-distribution @@ -68,7 +79,7 @@ jobs: verbose: true test-install-from-testpypi: - name: Test installation from TestPyPI + name: Test Installation from TestPyPI needs: - testpypi-deployment - build-testpypi-package @@ -83,17 +94,25 @@ jobs: run: | VERSION=${{ needs.build-testpypi-package.outputs.version }} SUFFIX=${{ needs.build-testpypi-package.outputs.suffix }} - python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ + for i in {1..3}; do + if python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/; then + echo "Package installed successfully." + break + else + echo "Attempt $i failed. Retrying in 10 seconds..." + sleep 10 + fi + done build-pypi-package: - name: Build package for PyPI + name: Build Package for PyPI runs-on: ubuntu-latest needs: - test-install-from-testpypi outputs: version: ${{ steps.get_version.outputs.version }} steps: - - name: Checkout code + - name: Checkout Code uses: actions/checkout@v4 - name: Set up Python @@ -101,22 +120,22 @@ jobs: with: python-version: "3.x" - - name: Install setuptools_git_versioning + - name: Install Automatic Versioning Tool run: | python -m pip install setuptools-git-versioning - - name: Get current version + - name: Get Current Version id: get_version run: | VERSION=$(python -m setuptools_git_versioning) echo "VERSION=$VERSION" >> $GITHUB_OUTPUT - - name: Build package + - name: Build Package run: | python -m pip install --upgrade build python -m build - - name: Store PyPi distribution + - name: Store PyPI Distribution uses: actions/upload-artifact@v4 with: name: pypi-package-distribution @@ -138,13 +157,13 @@ jobs: id-token: write steps: - - name: Download PyPi distribution + - name: Download PyPI Distribution uses: actions/download-artifact@v4 with: name: pypi-package-distribution path: dist/ - - name: Create tag + - name: Create Tag uses: actions/github-script@v7 with: script: | @@ -155,7 +174,7 @@ jobs: sha: context.sha }) - - name: Create release + - name: Create Release uses: actions/github-script@v7 with: github-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 427b62e96..53d1343c7 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -2,17 +2,21 @@ name: PyTest on: pull_request: + push: + branches: + - main jobs: get_dockerfiles: - name: Get list of dockerfiles for different containers + name: Get List of Dockerfiles for Containers runs-on: ubuntu-latest permissions: packages: read outputs: imagenames: ${{ steps.set-matrix.outputs.imagenames }} steps: - - id: set-matrix + - name: Retrieve Docker Image Names + id: set-matrix env: GH_TOKEN: ${{ secrets.GHCR_TOKEN }} run: | @@ -35,58 +39,56 @@ jobs: echo "image names with tag latest: $imagenames_latest" echo "imagenames=$imagenames_latest" >> $GITHUB_OUTPUT - - name: Dockerfile overview + - name: Dockerfile Overview run: | - echo "final list of images with tag latest: ${{ steps.set-matrix.outputs.imagenames }}" + echo "Final list of images with tag latest: ${{ steps.set-matrix.outputs.imagenames }}" test: - name: Run tests and get coverage report + name: Run Tests and Coverage Report needs: get_dockerfiles runs-on: ubuntu-latest permissions: pull-requests: write contents: write strategy: - fail-fast: false + fail-fast: false matrix: imagename: ${{ fromJson(needs.get_dockerfiles.outputs.imagenames) }} - # runs within Docker container container: image: ghcr.io/ptb-mr/${{ matrix.imagename }}:latest options: --user runner - steps: - - name: Checkout repo + - name: Checkout Repository uses: actions/checkout@v4 - - name: Install mrpro and dependencies - run: pip install --upgrade --upgrade-strategy "eager" .[test] + - name: Install MRpro and Dependencies + run: pip install --upgrade --upgrade-strategy eager .[test] - - name: Install pytest-github-actions-annotate-failures plugin + - name: Install PyTest GitHub Annotation Plugin run: pip install pytest-github-actions-annotate-failures - - name: Run PyTest + - name: Run PyTest and Generate Coverage Report run: | - pytest -n 4 -m "not cuda" --junitxml=pytest.xml --cov-report=term-missing:skip-covered --cov=mrpro | tee pytest-coverage.txt + pytest -n 4 -m "not cuda" --junitxml=pytest.xml \ + --cov-report=term-missing:skip-covered --cov=mrpro | tee pytest-coverage.txt - - name: Check for pytest.xml + - name: Verify PyTest XML Output run: | - if [ -f pytest.xml ]; then - echo "pytest.xml file found. Continuing..." - else - echo "pytest.xml file not found. Please check previous 'Run PyTest' section for errors." + if [ ! -f pytest.xml ]; then + echo "PyTest XML report not found. Please check the previous 'Run PyTest' step for errors." exit 1 fi - - name: Pytest coverage comment + - name: Post PyTest Coverage Comment id: coverageComment uses: MishaKav/pytest-coverage-comment@v1.1.53 with: pytest-coverage-path: ./pytest-coverage.txt junitxml-path: ./pytest.xml - - name: Create the Badge + - name: Create Coverage Badge on Main Branch Push uses: schneegans/dynamic-badges-action@v1.7.0 + if: github.event_name == 'push' && github.ref == 'refs/heads/main' with: auth: ${{ secrets.GIST_SECRET }} gistID: 48e334a10caf60e6708d7c712e56d241 @@ -96,12 +98,12 @@ jobs: color: ${{ steps.coverageComment.outputs.color }} namedLogo: python - - name: Set pipeline status + - name: Set Pipeline Status Based on Test Results if: steps.coverageComment.outputs.errors != 0 || steps.coverageComment.outputs.failures != 0 uses: actions/github-script@v7 with: script: | - core.setFailed('PyTest workflow failed with ${{ steps.coverageComment.outputs.errors }} errors and ${{ steps.coverageComment.outputs.failures }} failures.') + core.setFailed("PyTest workflow failed with ${{ steps.coverageComment.outputs.errors }} errors and ${{ steps.coverageComment.outputs.failures }} failures.") concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b14c4344b..303fd43fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,10 +3,9 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - - id: check-docstring-first - id: check-merge-conflict - id: check-yaml - id: check-toml @@ -15,19 +14,26 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff # linter args: [--fix] - id: ruff-format # formatter - repo: https://github.com/crate-ci/typos - rev: v1.25.0 + rev: v1.27.0 hooks: - id: typos + - repo: https://github.com/fzimmermann89/check_all + rev: v1.0 + hooks: + - id: check-init-all + args: [--double-quotes] + exclude: ^tests/ + - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy pass_filenames: false @@ -45,12 +51,13 @@ repos: - xsdata - "--index-url=https://download.pytorch.org/whl/cpu" - "--extra-index-url=https://pypi.python.org/simple" + ci: - autofix_commit_msg: | - [pre-commit] auto fixes from pre-commit hooks - autofix_prs: false - autoupdate_branch: '' - autoupdate_commit_msg: '[pre-commit] pre-commit autoupdate' - autoupdate_schedule: monthly - skip: [mypy] - submodules: false + autofix_commit_msg: | + [pre-commit] auto fixes from pre-commit hooks + autofix_prs: false + autoupdate_branch: "" + autoupdate_commit_msg: "[pre-commit] pre-commit autoupdate" + autoupdate_schedule: monthly + skip: [mypy] + submodules: false diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index d0e6e0adb..52e0310bb 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -76,7 +76,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_ismrmrd_traj = direct_reconstruction.forward(kdata)" + "img_using_ismrmrd_traj = direct_reconstruction(kdata)" ] }, { @@ -100,7 +100,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_rad2d_traj = direct_reconstruction.forward(kdata)" + "img_using_rad2d_traj = direct_reconstruction(kdata)" ] }, { @@ -143,7 +143,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_pulseq_traj = direct_reconstruction.forward(kdata)" + "img_using_pulseq_traj = direct_reconstruction(kdata)" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index 955b20a37..3f857c382 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -37,7 +37,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_ismrmrd_traj = direct_reconstruction.forward(kdata) +img_using_ismrmrd_traj = direct_reconstruction(kdata) # %% [markdown] # ### Image reconstruction using KTrajectoryRadial2D @@ -49,7 +49,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_rad2d_traj = direct_reconstruction.forward(kdata) +img_using_rad2d_traj = direct_reconstruction(kdata) # %% [markdown] # ### Image reconstruction using KTrajectoryPulseq @@ -73,7 +73,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_pulseq_traj = direct_reconstruction.forward(kdata) +img_using_pulseq_traj = direct_reconstruction(kdata) # %% [markdown] # ### Plot the different reconstructed images diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb new file mode 100644 index 000000000..6b1c2704b --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "af432293", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Regularized Iterative SENSE Reconstruction of 2D golden angle radial data\n", + "Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a7a6ce3", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# define zenodo URL of the example ismrmd data\n", + "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cd8486b", + "metadata": {}, + "outputs": [], + "source": [ + "# Download raw data\n", + "import tempfile\n", + "\n", + "import requests\n", + "\n", + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" + ] + }, + { + "cell_type": "markdown", + "id": "6a9defa1", + "metadata": {}, + "source": [ + "### Image reconstruction\n", + "We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data.\n", + "RegularizedIterativeSENSEReconstruction solves the following reconstruction problem:\n", + "\n", + "Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms,\n", + "coil sensitivity maps...) $A$ then we can formulate the forward problem as:\n", + "\n", + "$ y = Ax + n $\n", + "\n", + "where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 $\n", + "\n", + "where $W^\\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal\n", + "operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain\n", + "a solution with certain properties:\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$\n", + "\n", + "where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image.\n", + "With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$.\n", + "\n", + "Setting the derivative of the functional $F$ to zero and rearranging yields\n", + "\n", + "$ (A^H W A + l B) x = A^H W y + l x_{reg}$\n", + "\n", + "which is a linear system $Hx = b$ that needs to be solved for $x$.\n", + "\n", + "One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution\n", + "dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks\n", + "has been used as an image regulariser.\n", + "\n", + "In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image.\n", + "Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using\n", + "only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the\n", + "regularization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4da15c2", + "metadata": {}, + "outputs": [], + "source": [ + "import mrpro" + ] + }, + { + "cell_type": "markdown", + "id": "de055070", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Read-in the raw data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ac1d89f", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.data import KData\n", + "from mrpro.data.traj_calculators import KTrajectoryIsmrmrd\n", + "\n", + "# Load in the Data and the trajectory from the ISMRMRD file\n", + "kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd())\n", + "kdata.header.recon_matrix.x = 256\n", + "kdata.header.recon_matrix.y = 256" + ] + }, + { + "cell_type": "markdown", + "id": "1f389140", + "metadata": {}, + "source": [ + "##### Image $x_{reg}$ from fully sampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "212b915c", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction\n", + "from mrpro.data import CsmData\n", + "\n", + "# Estimate coil maps\n", + "direct_reconstruction = DirectReconstruction(kdata, csm=None)\n", + "img_coilwise = direct_reconstruction(kdata)\n", + "csm = CsmData.from_idata_walsh(img_coilwise)\n", + "\n", + "# Iterative SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3)\n", + "img_iterative_sense = iterative_sense_reconstruction(kdata)" + ] + }, + { + "cell_type": "markdown", + "id": "bec6b712", + "metadata": {}, + "source": [ + "##### Image $x$ from undersampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6740447", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Data undersampling, i.e. take only the first 20 radial lines\n", + "idx_us = torch.arange(0, 20)[None, :]\n", + "kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fbfd664", + "metadata": {}, + "outputs": [], + "source": [ + "# Iterativ SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6)\n", + "img_us_iterative_sense = iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "041ffe72", + "metadata": {}, + "outputs": [], + "source": [ + "# Regularized iterativ SENSE reconstruction\n", + "from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction\n", + "\n", + "regularization_weight = 1.0\n", + "n_iterations = 6\n", + "regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction(\n", + " kdata_us,\n", + " csm=csm,\n", + " n_iterations=n_iterations,\n", + " regularization_data=img_iterative_sense.data,\n", + " regularization_weight=regularization_weight,\n", + ")\n", + "img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d5bbec1", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()]\n", + "vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4))\n", + "for ind in range(3):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "2bd49c87", + "metadata": {}, + "source": [ + "### Behind the scenes" + ] + }, + { + "cell_type": "markdown", + "id": "53779251", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Set-up the density compensation operator $W$ and acquisition model $A$\n", + "\n", + "This is very similar to the iterative SENSE reconstruction. For more detail please look at the\n", + "iterative_sense_reconstruction notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e985a4f3", + "metadata": {}, + "outputs": [], + "source": [ + "dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator()\n", + "fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us)\n", + "csm_operator = csm.as_operator()\n", + "acquisition_operator = fourier_operator @ csm_operator" + ] + }, + { + "cell_type": "markdown", + "id": "2daa0fee", + "metadata": {}, + "source": [ + "##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac1d5fb4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "right_hand_side = (\n", + " acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "76a0b153", + "metadata": {}, + "source": [ + "##### Set-up the linear self-adjoint operator $H = A^H W A + l$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5effb592", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.operators import IdentityOp\n", + "\n", + "operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor(\n", + " regularization_weight\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f24a8588", + "metadata": {}, + "source": [ + "##### Run conjugate gradient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96827838", + "metadata": {}, + "outputs": [], + "source": [ + "img_manual = mrpro.algorithms.optimizers.cg(\n", + " operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18c065a7", + "metadata": {}, + "outputs": [], + "source": [ + "# Display the reconstructed image\n", + "vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]]\n", + "vis_title = ['Regularized Iterative SENSE R=20', '\"Manual\" Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4))\n", + "for ind in range(2):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "d6d7efdf", + "metadata": {}, + "source": [ + "### Check for equal results\n", + "The two versions should result in the same image data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f59b6015", + "metadata": {}, + "outputs": [], + "source": [ + "# If the assert statement did not raise an exception, the results are equal.\n", + "assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual)" + ] + }, + { + "cell_type": "markdown", + "id": "6ecd6e70", + "metadata": {}, + "source": [ + "### Next steps\n", + "Play around with the regularization_weight to see how it effects the final image quality.\n", + "\n", + "Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications\n", + "we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the\n", + "streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality\n", + "compared to the unregularised images." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py new file mode 100644 index 000000000..e41dc4ac5 --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -0,0 +1,193 @@ +# %% [markdown] +# # Regularized Iterative SENSE Reconstruction of 2D golden angle radial data +# Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data +# %% +# define zenodo URL of the example ismrmd data +zenodo_url = 'https://zenodo.org/records/10854057/files/' +fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' +# %% +# Download raw data +import tempfile + +import requests + +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() + +# %% [markdown] +# ### Image reconstruction +# We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data. +# RegularizedIterativeSENSEReconstruction solves the following reconstruction problem: +# +# Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms, +# coil sensitivity maps...) $A$ then we can formulate the forward problem as: +# +# $ y = Ax + n $ +# +# where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$ +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 $ +# +# where $W^\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal +# operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain +# a solution with certain properties: +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$ +# +# where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image. +# With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$. +# +# Setting the derivative of the functional $F$ to zero and rearranging yields +# +# $ (A^H W A + l B) x = A^H W y + l x_{reg}$ +# +# which is a linear system $Hx = b$ that needs to be solved for $x$. +# +# One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution +# dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks +# has been used as an image regulariser. +# +# In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image. +# Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using +# only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the +# regularization. + +# %% +import mrpro + +# %% [markdown] +# ##### Read-in the raw data +# %% +from mrpro.data import KData +from mrpro.data.traj_calculators import KTrajectoryIsmrmrd + +# Load in the Data and the trajectory from the ISMRMRD file +kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd()) +kdata.header.recon_matrix.x = 256 +kdata.header.recon_matrix.y = 256 + +# %% [markdown] +# ##### Image $x_{reg}$ from fully sampled data + +# %% +from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction +from mrpro.data import CsmData + +# Estimate coil maps +direct_reconstruction = DirectReconstruction(kdata, csm=None) +img_coilwise = direct_reconstruction(kdata) +csm = CsmData.from_idata_walsh(img_coilwise) + +# Iterative SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3) +img_iterative_sense = iterative_sense_reconstruction(kdata) + +# %% [markdown] +# ##### Image $x$ from undersampled data + +# %% +import torch + +# Data undersampling, i.e. take only the first 20 radial lines +idx_us = torch.arange(0, 20)[None, :] +kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition') + +# %% +# Iterativ SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6) +img_us_iterative_sense = iterative_sense_reconstruction(kdata_us) + +# %% +# Regularized iterativ SENSE reconstruction +from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction + +regularization_weight = 1.0 +n_iterations = 6 +regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction( + kdata_us, + csm=csm, + n_iterations=n_iterations, + regularization_data=img_iterative_sense.data, + regularization_weight=regularization_weight, +) +img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us) + +# %% +import matplotlib.pyplot as plt + +vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()] +vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4)) +for ind in range(3): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + + +# %% [markdown] +# ### Behind the scenes + +# %% [markdown] +# ##### Set-up the density compensation operator $W$ and acquisition model $A$ +# +# This is very similar to the iterative SENSE reconstruction. For more detail please look at the +# iterative_sense_reconstruction notebook. +# %% +dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator() +fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us) +csm_operator = csm.as_operator() +acquisition_operator = fourier_operator @ csm_operator + +# %% [markdown] +# ##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$ + +# %% +right_hand_side = ( + acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data +) + + +# %% [markdown] +# ##### Set-up the linear self-adjoint operator $H = A^H W A + l$ + +# %% +from mrpro.operators import IdentityOp + +operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor( + regularization_weight +) + +# %% [markdown] +# ##### Run conjugate gradient + +# %% +img_manual = mrpro.algorithms.optimizers.cg( + operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0 +) + +# %% +# Display the reconstructed image +vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]] +vis_title = ['Regularized Iterative SENSE R=20', '"Manual" Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4)) +for ind in range(2): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + +# %% [markdown] +# ### Check for equal results +# The two versions should result in the same image data. + +# %% +# If the assert statement did not raise an exception, the results are equal. +assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual) + +# %% [markdown] +# ### Next steps +# Play around with the regularization_weight to see how it effects the final image quality. +# +# Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications +# we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the +# streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality +# compared to the unregularised images. diff --git a/examples/ruff.toml b/examples/ruff.toml index 11a1e6167..1bb114755 100644 --- a/examples/ruff.toml +++ b/examples/ruff.toml @@ -5,4 +5,5 @@ lint.extend-ignore = [ "T20", #print "E402", #module-import-not-at-top-of-file "S101", #assert + "SIM115", #context manager for opening files ] diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 0a67c464f..c60039027 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241015 +0.241112 diff --git a/src/mrpro/__init__.py b/src/mrpro/__init__.py index 4a0ac4ca0..729ae188c 100644 --- a/src/mrpro/__init__.py +++ b/src/mrpro/__init__.py @@ -1,4 +1,10 @@ from mrpro._version import __version__ from mrpro import algorithms, operators, data, phantoms, utils -__all__ = ["algorithms", "operators", "data", "phantoms", "utils"] - +__all__ = [ + "__version__", + "algorithms", + "data", + "operators", + "phantoms", + "utils" +] diff --git a/src/mrpro/algorithms/__init__.py b/src/mrpro/algorithms/__init__.py index 6a5d8cd64..3dbc8ed07 100644 --- a/src/mrpro/algorithms/__init__.py +++ b/src/mrpro/algorithms/__init__.py @@ -1,3 +1,3 @@ from mrpro.algorithms import csm, optimizers, reconstruction from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -__all__ = ["csm", "optimizers", "reconstruction", "prewhiten_kspace"] +__all__ = ["csm", "optimizers", "prewhiten_kspace", "reconstruction"] \ No newline at end of file diff --git a/src/mrpro/algorithms/csm/__init__.py b/src/mrpro/algorithms/csm/__init__.py index 499984f53..255e09a89 100644 --- a/src/mrpro/algorithms/csm/__init__.py +++ b/src/mrpro/algorithms/csm/__init__.py @@ -1,3 +1,3 @@ from mrpro.algorithms.csm.walsh import walsh from mrpro.algorithms.csm.inati import inati -__all__ = ["walsh", "inati"] +__all__ = ["inati", "walsh"] \ No newline at end of file diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index 54cb1d782..b879796b0 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -81,11 +81,6 @@ def cg( if torch.vdot(residual.flatten(), residual.flatten()) == 0: return solution - # squared tolerance; - # (we will check ||residual||^2 < tolerance^2 instead of ||residual|| < tol - # to avoid the computation of the root for the norm) - tolerance_squared = tolerance**2 - # dummy value. new value will be set in loop before first usage residual_norm_squared_previous = None @@ -95,7 +90,7 @@ def cg( residual_norm_squared = torch.vdot(residual_flat, residual_flat).real # check if the solution is already accurate enough - if tolerance != 0 and (residual_norm_squared < tolerance_squared): + if tolerance != 0 and (residual_norm_squared < tolerance**2): return solution if iteration > 0: @@ -105,8 +100,8 @@ def cg( # update estimates of the solution and the residual (operator_conjugate_vector,) = operator(conjugate_vector) alpha = residual_norm_squared / (torch.vdot(conjugate_vector.flatten(), operator_conjugate_vector.flatten())) - solution += alpha * conjugate_vector - residual -= alpha * operator_conjugate_vector + solution = solution + alpha * conjugate_vector + residual = residual - alpha * operator_conjugate_vector residual_norm_squared_previous = residual_norm_squared diff --git a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py index 04d37128c..32785a91a 100644 --- a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py +++ b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py @@ -4,20 +4,17 @@ from collections.abc import Callable -import torch - -from mrpro.algorithms.optimizers.cg import cg -from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import ( + RegularizedIterativeSENSEReconstruction, +) from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData -from mrpro.data.IData import IData from mrpro.data.KNoise import KNoise from mrpro.operators.LinearOperator import LinearOperator -class IterativeSENSEReconstruction(DirectReconstruction): +class IterativeSENSEReconstruction(RegularizedIterativeSENSEReconstruction): r"""Iterative SENSE reconstruction. This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2` @@ -49,6 +46,8 @@ def __init__( ) -> None: """Initialize IterativeSENSEReconstruction. + For a regularized version of the iterative SENSE algorithm please see RegularizedIterativeSENSEReconstruction. + Parameters ---------- kdata @@ -74,76 +73,4 @@ def __init__( ValueError If the kdata and fourier_op are None or if csm is a Callable but kdata is None. """ - super().__init__(kdata, fourier_op, csm, noise, dcf) - self.n_iterations = n_iterations - - def _self_adjoint_operator(self) -> LinearOperator: - """Create the self-adjoint operator. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Create the self-adjoint operator as :math:`H = A^H W A` if the DCF is not None otherwise as :math:`H = A^H A`. - """ - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Create H = A^H W A - operator = operator.H @ dcf_operator @ operator - else: - # Create H = A^H A - operator = operator.H @ operator - - return operator - - def _right_hand_side(self, kdata: KData) -> torch.Tensor: - """Calculate the right-hand-side of the normal equation. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Calculate the right-hand-side as :math:`b = A^H W y` if the DCF is not None otherwise as :math:`b = A^H y`. - - Parameters - ---------- - kdata - k-space data to reconstruct. - """ - device = kdata.data.device - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Calculate b = A^H W y - (right_hand_side,) = operator.to(device).H(dcf_operator(kdata.data)[0]) - else: - # Calculate b = A^H y - (right_hand_side,) = operator.to(device).H(kdata.data) - - return right_hand_side - - def forward(self, kdata: KData) -> IData: - """Apply the reconstruction. - - Parameters - ---------- - kdata - k-space data to reconstruct. - - Returns - ------- - the reconstruced image. - """ - device = kdata.data.device - if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) - - operator = self._self_adjoint_operator().to(device) - right_hand_side = self._right_hand_side(kdata) - - img_tensor = cg( - operator, right_hand_side, initial_value=right_hand_side, max_iterations=self.n_iterations, tolerance=0.0 - ) - img = IData.from_tensor_and_kheader(img_tensor, kdata.header) - return img + super().__init__(kdata, fourier_op, csm, noise, dcf, n_iterations=n_iterations, regularization_weight=0) diff --git a/src/mrpro/algorithms/reconstruction/Reconstruction.py b/src/mrpro/algorithms/reconstruction/Reconstruction.py index 127be8d5f..c4208157e 100644 --- a/src/mrpro/algorithms/reconstruction/Reconstruction.py +++ b/src/mrpro/algorithms/reconstruction/Reconstruction.py @@ -101,15 +101,13 @@ def direct_reconstruction(self, kdata: KData) -> IData: ------- image data """ - device = kdata.data.device if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) + kdata = prewhiten_kspace(kdata, self.noise) operator = self.fourier_op if self.csm is not None: operator = operator @ self.csm.as_operator() if self.dcf is not None: operator = self.dcf.as_operator() @ operator - operator = operator.to(device) (img_tensor,) = operator.H(kdata.data) img = IData.from_tensor_and_kheader(img_tensor, kdata.header) return img diff --git a/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py new file mode 100644 index 000000000..c9a307ebe --- /dev/null +++ b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py @@ -0,0 +1,139 @@ +"""Regularized Iterative SENSE Reconstruction by adjoint Fourier transform.""" + +from __future__ import annotations + +from collections.abc import Callable + +import torch + +from mrpro.algorithms.optimizers.cg import cg +from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace +from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.data._kdata.KData import KData +from mrpro.data.CsmData import CsmData +from mrpro.data.DcfData import DcfData +from mrpro.data.IData import IData +from mrpro.data.KNoise import KNoise +from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperator import LinearOperator + + +class RegularizedIterativeSENSEReconstruction(DirectReconstruction): + r"""Regularized iterative SENSE reconstruction. + + This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2 + + \frac{1}{2}L||Bx - x_0||_2^2` + by using a conjugate gradient algorithm to solve + :math:`H x = b` with :math:`H = A^H W A + L B` and :math:`b = A^H W y + L x_0` where :math:`A` + is the acquisition model (coil sensitivity maps, Fourier operator, k-space sampling), :math:`y` is the acquired + k-space data, :math:`W` describes the density compensation, :math:`L` is the strength of the regularization and + :math:`x_0` is the regularization image (i.e. the prior). :math:`B` is a linear operator applied to :math:`x`. + """ + + n_iterations: int + """Number of CG iterations.""" + + regularization_data: torch.Tensor + """Regularization data (i.e. prior) :math:`x_0`.""" + + regularization_weight: torch.Tensor + """Strength of the regularization :math:`L`.""" + + regularization_op: LinearOperator + """Linear operator :math:`B` applied to the current estimate in the regularization term.""" + + def __init__( + self, + kdata: KData | None = None, + fourier_op: LinearOperator | None = None, + csm: Callable | CsmData | None = CsmData.from_idata_walsh, + noise: KNoise | None = None, + dcf: DcfData | None = None, + *, + n_iterations: int = 5, + regularization_data: float | torch.Tensor = 0.0, + regularization_weight: float | torch.Tensor, + regularization_op: LinearOperator | None = None, + ) -> None: + """Initialize RegularizedIterativeSENSEReconstruction. + + For a unregularized version of the iterative SENSE algorithm the regularization_weight can be set to 0 or + IterativeSENSEReconstruction algorithm can be used. + + Parameters + ---------- + kdata + KData. If kdata is provided and fourier_op or dcf are None, then fourier_op and dcf are estimated based on + kdata. Otherwise fourier_op and dcf are used as provided. + fourier_op + Instance of the FourierOperator used for reconstruction. If None, set up based on kdata. + csm + Sensitivity maps for coil combination. If None, no coil combination is carried out, i.e. images for each + coil are returned. If a callable is provided, coil images are reconstructed using the adjoint of the + FourierOperator (including density compensation) and then sensitivity maps are calculated using the + callable. For this, kdata needs also to be provided. For examples have a look at the CsmData class + e.g. from_idata_walsh or from_idata_inati. + noise + KNoise used for prewhitening. If None, no prewhitening is performed + dcf + K-space sampling density compensation. If None, set up based on kdata. + n_iterations + Number of CG iterations + regularization_data + Regularization data, e.g. a reference image (:math:`x_0`). + regularization_weight + Strength of the regularization (:math:`L`). + regularization_op + Linear operator :math:`B` applied to the current estimate in the regularization term. If None, nothing is + applied to the current estimate. + + + Raises + ------ + ValueError + If the kdata and fourier_op are None or if csm is a Callable but kdata is None. + """ + super().__init__(kdata, fourier_op, csm, noise, dcf) + self.n_iterations = n_iterations + self.regularization_data = torch.as_tensor(regularization_data) + self.regularization_weight = torch.as_tensor(regularization_weight) + self.regularization_op = regularization_op if regularization_op is not None else IdentityOp() + + def forward(self, kdata: KData) -> IData: + """Apply the reconstruction. + + Parameters + ---------- + kdata + k-space data to reconstruct. + + Returns + ------- + the reconstruced image. + """ + if self.noise is not None: + kdata = prewhiten_kspace(kdata, self.noise) + + # Create the normal operator as H = A^H W A if the DCF is not None otherwise as H = A^H A. + # The acquisition model is A = F S if the CSM S is defined otherwise A = F with the Fourier operator F + csm_op = self.csm.as_operator() if self.csm is not None else IdentityOp() + precondition_op = self.dcf.as_operator() if self.dcf is not None else IdentityOp() + operator = (self.fourier_op @ csm_op).H @ precondition_op @ (self.fourier_op @ csm_op) + + # Calculate the right-hand-side as b = A^H W y if the DCF is not None otherwise as b = A^H y. + (right_hand_side,) = (self.fourier_op @ csm_op).H(precondition_op(kdata.data)[0]) + + # Add regularization + if not torch.all(self.regularization_weight == 0): + operator = operator + IdentityOp() @ (self.regularization_weight * self.regularization_op) + right_hand_side += self.regularization_weight * self.regularization_data + + img_tensor = cg( + operator, + right_hand_side, + initial_value=right_hand_side, + max_iterations=self.n_iterations, + tolerance=0.0, + ) + img = IData.from_tensor_and_kheader(img_tensor, kdata.header) + return img diff --git a/src/mrpro/algorithms/reconstruction/__init__.py b/src/mrpro/algorithms/reconstruction/__init__.py index 1d292cf37..38b539f8b 100644 --- a/src/mrpro/algorithms/reconstruction/__init__.py +++ b/src/mrpro/algorithms/reconstruction/__init__.py @@ -1,4 +1,10 @@ from mrpro.algorithms.reconstruction.Reconstruction import Reconstruction from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import RegularizedIterativeSENSEReconstruction from mrpro.algorithms.reconstruction.IterativeSENSEReconstruction import IterativeSENSEReconstruction -__all__ = ["Reconstruction", "DirectReconstruction", "IterativeSENSEReconstruction"] +__all__ = [ + "DirectReconstruction", + "IterativeSENSEReconstruction", + "Reconstruction", + "RegularizedIterativeSENSEReconstruction" +] diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 83f752a57..f5d677f97 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -1,28 +1,29 @@ """Acquisition information dataclass.""" -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass import ismrmrd import numpy as np import torch -from typing_extensions import Self, TypeVar +from einops import rearrange +from typing_extensions import Self from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.utils.unit_conversion import mm_to_m -# Conversion functions for units -T = TypeVar('T', float, torch.Tensor) +def rearrange_acq_info_fields(field: object, pattern: str, **axes_lengths: dict[str, int]) -> object: + """Change the shape of the fields in AcqInfo.""" + if isinstance(field, Rotation): + return Rotation.from_matrix(rearrange(field.as_matrix(), pattern, **axes_lengths)) -def ms_to_s(ms: T) -> T: - """Convert ms to s.""" - return ms / 1000 + if isinstance(field, torch.Tensor): + return rearrange(field, pattern, **axes_lengths) - -def mm_to_m(m: T) -> T: - """Convert mm to m.""" - return m / 1000 + return field @dataclass(slots=True) @@ -121,30 +122,24 @@ class AcqInfo(MoveDataMixin): number_of_samples: torch.Tensor """Number of sample points per readout (readouts may have different number of sample points).""" + orientation: Rotation + """Rotation describing the orientation of the readout, phase and slice encoding direction.""" + patient_table_position: SpatialDimension[torch.Tensor] """Offset position of the patient table, in LPS coordinates [m].""" - phase_dir: SpatialDimension[torch.Tensor] - """Directional cosine of phase encoding (2D).""" - physiology_time_stamp: torch.Tensor """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units""" position: SpatialDimension[torch.Tensor] """Center of the excited volume, in LPS coordinates relative to isocenter [m].""" - read_dir: SpatialDimension[torch.Tensor] - """Directional cosine of readout/frequency encoding.""" - sample_time_us: torch.Tensor """Readout bandwidth, as time between samples [us].""" scan_counter: torch.Tensor """Zero-indexed incrementing counter for readouts.""" - slice_dir: SpatialDimension[torch.Tensor] - """Directional cosine of slice normal, i.e. cross-product of read_dir and phase_dir.""" - trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists. """Dimensionality of the k-space trajectory vector.""" @@ -206,17 +201,13 @@ def tensor_2d(data: np.ndarray) -> torch.Tensor: data_tensor = data_tensor[None, None] return data_tensor - def spatialdimension_2d( - data: np.ndarray, conversion: Callable[[torch.Tensor], torch.Tensor] | None = None - ) -> SpatialDimension[torch.Tensor]: + def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: # Ensure spatial dimension is (k1*k2*other, 1, 3) if data.ndim != 2: raise ValueError('Spatial dimension is expected to be of shape (N,3)') data = data[:, None, :] # all spatial dimensions are float32 - return ( - SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))).apply_(conversion) - ) + return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))) acq_idx = AcqIdx( k1=tensor(idx['kspace_encode_step_1']), @@ -251,14 +242,16 @@ def spatialdimension_2d( flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), - patient_table_position=spatialdimension_2d(headers['patient_table_position'], mm_to_m), - phase_dir=spatialdimension_2d(headers['phase_dir']), + orientation=Rotation.from_directions( + spatialdimension_2d(headers['slice_dir']), + spatialdimension_2d(headers['phase_dir']), + spatialdimension_2d(headers['read_dir']), + ), + patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), - position=spatialdimension_2d(headers['position'], mm_to_m), - read_dir=spatialdimension_2d(headers['read_dir']), + position=spatialdimension_2d(headers['position']).apply_(mm_to_m), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), - slice_dir=spatialdimension_2d(headers['slice_dir']), trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above user_float=tensor_2d(headers['user_float']), user_int=tensor_2d(headers['user_int']), diff --git a/src/mrpro/data/KHeader.py b/src/mrpro/data/KHeader.py index 488c1d9cd..dea12e28b 100644 --- a/src/mrpro/data/KHeader.py +++ b/src/mrpro/data/KHeader.py @@ -13,12 +13,12 @@ from typing_extensions import Self from mrpro.data import enums -from mrpro.data.AcqInfo import AcqInfo, mm_to_m, ms_to_s +from mrpro.data.AcqInfo import AcqInfo from mrpro.data.EncodingLimits import EncodingLimits from mrpro.data.MoveDataMixin import MoveDataMixin from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues +from mrpro.utils.unit_conversion import mm_to_m, ms_to_s if TYPE_CHECKING: # avoid circular imports by importing only when type checking @@ -40,9 +40,6 @@ class KHeader(MoveDataMixin): trajectory: KTrajectoryCalculator """Function to calculate the k-space trajectory.""" - b0: float - """Magnetic field strength [T].""" - encoding_limits: EncodingLimits """K-space encoding limits.""" @@ -61,12 +58,9 @@ class KHeader(MoveDataMixin): acq_info: AcqInfo """Information of the acquisitions (i.e. readout lines).""" - h1_freq: float + lamor_frequency_proton: float """Lamor frequency of hydrogen nuclei [Hz].""" - n_coils: int | None = None - """Number of receiver coils.""" - datetime: datetime.datetime | None = None """Date and time of acquisition.""" @@ -88,7 +82,7 @@ class KHeader(MoveDataMixin): echo_train_length: int = 1 """Number of echoes in a multi-echo acquisition.""" - seq_type: str = UNKNOWN + sequence_type: str = UNKNOWN """Type of sequence.""" model: str = UNKNOWN @@ -100,16 +94,13 @@ class KHeader(MoveDataMixin): protocol_name: str = UNKNOWN """Name of the acquisition protocol.""" - misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! - """Dictionary with miscellaneous parameters.""" - calibration_mode: enums.CalibrationMode = enums.CalibrationMode.OTHER """Mode of how calibration data is acquired. """ interleave_dim: enums.InterleavingDimension = enums.InterleavingDimension.OTHER """Interleaving dimension.""" - traj_type: enums.TrajectoryType = enums.TrajectoryType.OTHER + trajectory_type: enums.TrajectoryType = enums.TrajectoryType.OTHER """Type of trajectory.""" measurement_id: str = UNKNOWN @@ -118,8 +109,9 @@ class KHeader(MoveDataMixin): patient_name: str = UNKNOWN """Name of the patient.""" - trajectory_description: TrajectoryDescription = dataclasses.field(default_factory=TrajectoryDescription) - """Description of the trajectory.""" + _misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! + """Dictionary with miscellaneous parameters. These parameters are for information purposes only. Reconstruction + algorithms should not rely on them.""" @property def fa_degree(self) -> torch.Tensor | None: @@ -160,17 +152,14 @@ def from_ismrmrd( enc: ismrmrdschema.encodingType = header.encoding[encoding_number] # These are guaranteed to exist - parameters = {'h1_freq': header.experimentalConditions.H1resonanceFrequency_Hz, 'acq_info': acq_info} + parameters = { + 'lamor_frequency_proton': header.experimentalConditions.H1resonanceFrequency_Hz, + 'acq_info': acq_info, + } if defaults is not None: parameters.update(defaults) - if ( - header.acquisitionSystemInformation is not None - and header.acquisitionSystemInformation.receiverChannels is not None - ): - parameters['n_coils'] = header.acquisitionSystemInformation.receiverChannels - if header.sequenceParameters is not None: if header.sequenceParameters.TR: parameters['tr'] = ms_to_s(torch.as_tensor(header.sequenceParameters.TR)) @@ -184,7 +173,7 @@ def from_ismrmrd( parameters['echo_spacing'] = ms_to_s(torch.as_tensor(header.sequenceParameters.echo_spacing)) if header.sequenceParameters.sequence_type is not None: - parameters['seq_type'] = header.sequenceParameters.sequence_type + parameters['sequence_type'] = header.sequenceParameters.sequence_type if enc.reconSpace is not None: parameters['recon_fov'] = SpatialDimension[float].from_xyz(enc.reconSpace.fieldOfView_mm).apply_(mm_to_m) @@ -212,7 +201,7 @@ def from_ismrmrd( ) if enc.trajectory is not None: - parameters['traj_type'] = enums.TrajectoryType(enc.trajectory.value) + parameters['trajectory_type'] = enums.TrajectoryType(enc.trajectory.value) # Either use the series or study time if available if header.measurementInformation is not None and header.measurementInformation.seriesTime is not None: @@ -245,15 +234,8 @@ def from_ismrmrd( if header.acquisitionSystemInformation.systemModel is not None: parameters['model'] = header.acquisitionSystemInformation.systemModel - if header.acquisitionSystemInformation.systemFieldStrength_T is not None: - parameters['b0'] = header.acquisitionSystemInformation.systemFieldStrength_T - - # estimate b0 from h1_freq if not given - if 'b0' not in parameters: - parameters['b0'] = parameters['h1_freq'] / 4258e4 - # Dump everything into misc - parameters['misc'] = dataclasses.asdict(header) + parameters['_misc'] = dataclasses.asdict(header) if overwrite is not None: parameters.update(overwrite) diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index f3f147260..8d977d0a6 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -1,13 +1,13 @@ """MoveDataMixin.""" import dataclasses -from collections.abc import Iterator +from collections.abc import Callable, Iterator from copy import copy as shallowcopy from copy import deepcopy -from typing import ClassVar, TypeAlias +from typing import ClassVar, TypeAlias, cast import torch -from typing_extensions import Any, Protocol, Self, overload, runtime_checkable +from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable class InconsistentDeviceError(ValueError): # noqa: D101 @@ -22,6 +22,9 @@ class DataclassInstance(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +T = TypeVar('T') + + class MoveDataMixin: """Move dataclass fields to cpu/gpu and convert dtypes.""" @@ -151,7 +154,6 @@ def _to( copy: bool = False, memo: dict | None = None, ) -> Self: - new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self """Move data to device and convert dtype if necessary. This method is called by .to(), .cuda(), .cpu(), .double(), and so on. @@ -179,6 +181,8 @@ def _to( memo A dictionary to keep track of already converted objects to avoid multiple conversions. """ + new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + if memo is None: memo = {} @@ -219,26 +223,62 @@ def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin: memo=memo, ) - converted: Any - for name, data in new._items(): - if id(data) in memo: - object.__setattr__(new, name, memo[id(data)]) - continue + def _convert(data: T) -> T: + converted: Any # https://github.com/python/mypy/issues/10817 if isinstance(data, torch.Tensor): converted = _tensor_to(data) elif isinstance(data, MoveDataMixin): converted = _mixin_to(data) elif isinstance(data, torch.nn.Module): converted = _module_to(data) - elif copy: - converted = deepcopy(data) else: converted = data - memo[id(data)] = converted - # this works even if new is frozen - object.__setattr__(new, name, converted) + return cast(T, converted) + + # manual recursion allows us to do the copy only once + new.apply_(_convert, memo=memo, recurse=False) return new + def apply_( + self: Self, + function: Callable[[Any], Any] | None = None, + *, + memo: dict[int, Any] | None = None, + recurse: bool = True, + ) -> Self: + """Apply a function to all children in-place. + + Parameters + ---------- + function + The function to apply to all fields. None is interpreted as a no-op. + memo + A dictionary to keep track of objects that the function has already been applied to, + to avoid multiple applications. This is useful if the object has a circular reference. + recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + applied: Any + + if memo is None: + memo = {} + + if function is None: + return self + + for name, data in self._items(): + if id(data) in memo: + # this works even if self is frozen + object.__setattr__(self, name, memo[id(data)]) + continue + if recurse and isinstance(data, MoveDataMixin): + applied = data.apply_(function, memo=memo) + else: + applied = function(data) + memo[id(data)] = applied + object.__setattr__(self, name, applied) + return self + def cuda( self, device: torch.device | str | int | None = None, diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index b5f3dfd27..46b1db89a 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -3,14 +3,13 @@ from __future__ import annotations from collections.abc import Callable -from copy import deepcopy from dataclasses import dataclass from typing import Generic, get_args import numpy as np import torch from numpy.typing import ArrayLike -from typing_extensions import Any, Protocol, TypeVar, overload +from typing_extensions import Protocol, Self, TypeVar, overload import mrpro.utils.typing as type_utils from mrpro.data.MoveDataMixin import MoveDataMixin @@ -109,6 +108,16 @@ def from_array_zyx( return SpatialDimension(z, y, x) + def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: + """Apply a function to each z, y, x (in-place). + + Parameters + ---------- + function + function to apply + """ + return super(SpatialDimension, self).apply_(function) + @property def zyx(self) -> tuple[T_co, T_co, T_co]: """Return a z,y,x tuple.""" @@ -134,48 +143,6 @@ def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexe self.y[idx] = other.y self.x[idx] = other.x - def apply_(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z in-place. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - if func is not None: - self.z = func(self.z) - self.y = func(self.y) - self.x = func(self.x) - return self - - def apply(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - - def func_(x: Any) -> T_co: # noqa: ANN401 - if isinstance(x, torch.Tensor): - # use clone for autograd - x = x.clone() - else: - x = deepcopy(x) - if func is None: - return x - else: - return func(x) - - return self.__class__(func_(self.z), func_(self.y), func_(self.x)) - - def clone(self: SpatialDimension[T_co]) -> SpatialDimension[T_co]: - """Return a deep copy of the SpatialDimension.""" - return self.apply() - @overload def __mul__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... diff --git a/src/mrpro/data/TrajectoryDescription.py b/src/mrpro/data/TrajectoryDescription.py deleted file mode 100644 index 801811005..000000000 --- a/src/mrpro/data/TrajectoryDescription.py +++ /dev/null @@ -1,29 +0,0 @@ -"""TrajectoryDescription dataclass.""" - -import dataclasses -from dataclasses import dataclass - -from ismrmrd.xsd.ismrmrdschema.ismrmrd import trajectoryDescriptionType -from typing_extensions import Self - - -@dataclass(slots=True) -class TrajectoryDescription: - """TrajectoryDescription dataclass.""" - - identifier: str = '' - user_parameter_long: dict[str, int] = dataclasses.field(default_factory=dict) - user_parameter_double: dict[str, float] = dataclasses.field(default_factory=dict) - user_parameter_string: dict[str, str] = dataclasses.field(default_factory=dict) - comment: str = '' - - @classmethod - def from_ismrmrd(cls, trajectory_description: trajectoryDescriptionType) -> Self: - """Create TrajectoryDescription from ismrmrd traj description.""" - return cls( - user_parameter_long={p.name: int(p.value) for p in trajectory_description.userParameterLong}, - user_parameter_double={p.name: float(p.value) for p in trajectory_description.userParameterDouble}, - user_parameter_string={p.name: str(p.value) for p in trajectory_description.userParameterString}, - comment=trajectory_description.comment or '', - identifier=trajectory_description.identifier or '', - ) diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index a954ca1b3..d5667a5bc 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -16,5 +16,27 @@ from mrpro.data.QHeader import QHeader from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription -__all__ = ["enums", "traj_calculators", "acq_filters", "AcqIdx", "AcqInfo", "CsmData", "Data", "DcfData", "EncodingLimits", "Limits", "IData", "IHeader", "KData", "KHeader", "KNoise", "KTrajectory", "KTrajectoryRawShape", "MoveDataMixin", "QData", "QHeader", "Rotation", "SpatialDimension", "TrajectoryDescription"] +__all__ = [ + "AcqIdx", + "AcqInfo", + "CsmData", + "Data", + "DcfData", + "EncodingLimits", + "IData", + "IHeader", + "KData", + "KHeader", + "KNoise", + "KTrajectory", + "KTrajectoryRawShape", + "Limits", + "MoveDataMixin", + "QData", + "QHeader", + "Rotation", + "SpatialDimension", + "acq_filters", + "enums", + "traj_calculators" +] diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index 409e8aac9..aaf430497 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -17,16 +17,16 @@ from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin -from mrpro.data.acq_filters import is_image_acquisition -from mrpro.data.AcqInfo import AcqInfo +from mrpro.data.acq_filters import has_n_coils, is_image_acquisition +from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd -from mrpro.utils import modify_acq_info KDIM_SORT_LABELS = ( 'k1', @@ -110,6 +110,29 @@ def from_file( modification_time = datetime.datetime.fromtimestamp(mtime) acquisitions = [acq for acq in acquisitions if acquisition_filter_criterion(acq)] + + # we need the same number of receiver coils for all acquisitions + n_coils_available = {acq.data.shape[0] for acq in acquisitions} + if len(n_coils_available) > 1: + if ( + ismrmrd_header.acquisitionSystemInformation is not None + and ismrmrd_header.acquisitionSystemInformation.receiverChannels is not None + ): + n_coils = int(ismrmrd_header.acquisitionSystemInformation.receiverChannels) + else: + # most likely, highest number of elements are the coils used for imaging + n_coils = int(max(n_coils_available)) + + warnings.warn( + f'Acquisitions with different number {n_coils_available} of receiver coil elements detected. ' + f'Data with {n_coils} receiver coil elements will be used.', + stacklevel=1, + ) + acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)] + + if not acquisitions: + raise ValueError('No acquisitions meeting the given filter criteria were found.') + kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions]) acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions) @@ -200,10 +223,13 @@ def from_file( sort_idx = np.lexsort(acq_indices) # torch does not have lexsort as of pytorch 2.2 (March 2024) # Finally, reshape and sort the tensors in acqinfo and acqinfo.idx, and kdata. - def sort_and_reshape_tensor_fields(input_tensor: torch.Tensor): - return rearrange(input_tensor[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2) - - kheader.acq_info = modify_acq_info(sort_and_reshape_tensor_fields, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2 + ) + if isinstance(field, torch.Tensor | Rotation) + else field + ) kdata = rearrange(kdata[sort_idx], '(other k2 k1) coils k0 -> other coils k2 k1 k0', k1=n_k1, k2=n_k2) # Calculate trajectory and check if it matches the kdata shape diff --git a/src/mrpro/data/_kdata/KDataRearrangeMixin.py b/src/mrpro/data/_kdata/KDataRearrangeMixin.py index 05bab7681..23a58dea6 100644 --- a/src/mrpro/data/_kdata/KDataRearrangeMixin.py +++ b/src/mrpro/data/_kdata/KDataRearrangeMixin.py @@ -6,8 +6,7 @@ from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import AcqInfo -from mrpro.utils import modify_acq_info +from mrpro.data.AcqInfo import rearrange_acq_info_fields class KDataRearrangeMixin(_KDataProtocol): @@ -35,9 +34,8 @@ def rearrange_k2_k1_into_k1(self: Self) -> Self: kheader = copy.deepcopy(self.header) # Update shape of acquisition info index - def reshape_acq_info(info: AcqInfo): - return rearrange(info, 'other k2 k1 ... -> other 1 (k2 k1) ...') - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...') + ) return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataSelectMixin.py b/src/mrpro/data/_kdata/KDataSelectMixin.py index fb0e02aa1..8f8a452cf 100644 --- a/src/mrpro/data/_kdata/KDataSelectMixin.py +++ b/src/mrpro/data/_kdata/KDataSelectMixin.py @@ -7,7 +7,7 @@ from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation class KDataSelectMixin(_KDataProtocol): @@ -51,10 +51,9 @@ def select_other_subset( other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) # Adapt header - def select_acq_info(info: torch.Tensor): - return info[other_idx, ...] - - kheader.acq_info = modify_acq_info(select_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field + ) # Select data kdat = self.data[other_idx, ...] diff --git a/src/mrpro/data/_kdata/KDataSplitMixin.py b/src/mrpro/data/_kdata/KDataSplitMixin.py index d2c641125..c28004af4 100644 --- a/src/mrpro/data/_kdata/KDataSplitMixin.py +++ b/src/mrpro/data/_kdata/KDataSplitMixin.py @@ -1,15 +1,17 @@ """Mixin class to split KData into other subsets.""" -import copy -from typing import Literal +from typing import Literal, TypeVar, cast import torch from einops import rearrange, repeat from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation + +RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) class KDataSplitMixin(_KDataProtocol): @@ -56,8 +58,9 @@ def _split_k2_or_k1_into_other( def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, :, split_idx, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, :, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + # cast due to https://github.com/python/mypy/issues/10817 + return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' @@ -69,8 +72,8 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, split_idx, :, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + return cast(RotationOrTensor, acq_info[:, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' @@ -93,13 +96,14 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) # Create new header with correct shape - kheader = copy.deepcopy(self.header) + kheader = self.header.clone() # Update shape of acquisition info index - def reshape_acq_info(info: torch.Tensor): - return rearrange(split_acq_info(info), rearrange_pattern_acq_info) - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) + if isinstance(field, Rotation | torch.Tensor) + else field + ) # Update other label limits and acquisition info setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) diff --git a/src/mrpro/data/acq_filters.py b/src/mrpro/data/acq_filters.py index d64c4d9a6..4723d3bba 100644 --- a/src/mrpro/data/acq_filters.py +++ b/src/mrpro/data/acq_filters.py @@ -61,3 +61,20 @@ def is_coil_calibration_acquisition(acquisition: ismrmrd.Acquisition) -> bool: """ coil_calibration_flag = AcqFlags.ACQ_IS_PARALLEL_CALIBRATION | AcqFlags.ACQ_IS_PARALLEL_CALIBRATION_AND_IMAGING return coil_calibration_flag.value & acquisition.flags + + +def has_n_coils(n_coils: int, acquisition: ismrmrd.Acquisition) -> bool: + """Test if acquisitions was obtained with a certain number of receiver coils. + + Parameters + ---------- + n_coils + number of receiver coils + acquisition + ISMRMRD acquisition + + Returns + ------- + True if the acquisition was obtained with n_coils receiver coils + """ + return acquisition.data.shape[0] == n_coils diff --git a/src/mrpro/data/traj_calculators/__init__.py b/src/mrpro/data/traj_calculators/__init__.py index 72d61536b..2fd0e5b5a 100644 --- a/src/mrpro/data/traj_calculators/__init__.py +++ b/src/mrpro/data/traj_calculators/__init__.py @@ -5,4 +5,12 @@ from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd from mrpro.data.traj_calculators.KTrajectoryPulseq import KTrajectoryPulseq from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian -__all__ = ["KTrajectoryCalculator", "KTrajectoryRpe", "KTrajectorySunflowerGoldenRpe", "KTrajectoryRadial2D", "KTrajectoryIsmrmrd", "KTrajectoryPulseq", "KTrajectoryCartesian"] +__all__ = [ + "KTrajectoryCalculator", + "KTrajectoryCartesian", + "KTrajectoryIsmrmrd", + "KTrajectoryPulseq", + "KTrajectoryRadial2D", + "KTrajectoryRpe", + "KTrajectorySunflowerGoldenRpe" +] \ No newline at end of file diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 64068a5d1..07f8aba65 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -1,5 +1,7 @@ """Cartesian Sampling Operator.""" +import warnings + import torch from einops import rearrange, repeat @@ -7,6 +9,7 @@ from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension from mrpro.operators.LinearOperator import LinearOperator +from mrpro.utils.reshape import unsqueeze_left class CartesianSamplingOp(LinearOperator): @@ -47,26 +50,53 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2 else: sorted_grid_shape.x = ktraj_tensor.shape[-1] - kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1) + kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k2 k1 k0', other=1, k2=1, k1=1) if traj_type_kzyx[-2] == TrajType.ONGRID: # ky ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2 else: sorted_grid_shape.y = ktraj_tensor.shape[-2] - ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1) + ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k2 k1 k0', other=1, k2=1, k0=1) if traj_type_kzyx[-3] == TrajType.ONGRID: # kz kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2 else: sorted_grid_shape.z = ktraj_tensor.shape[-3] - kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1) + kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k2 k1 k0', other=1, k1=1, k0=1) # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') + + # check that all points are inside the encoding matrix + inside_encoding_matrix = ( + ((kx_idx >= 0) & (kx_idx < sorted_grid_shape.x)) + & ((ky_idx >= 0) & (ky_idx < sorted_grid_shape.y)) + & ((kz_idx >= 0) & (kz_idx < sorted_grid_shape.z)) + ) + if not torch.all(inside_encoding_matrix): + warnings.warn( + 'K-space points lie outside of the encoding_matrix and will be ignored.' + ' Increase the encoding_matrix to include these points.', + stacklevel=2, + ) + + inside_encoding_matrix = rearrange(inside_encoding_matrix, '... kz ky kx -> ... 1 (kz ky kx)') + inside_encoding_matrix_idx = inside_encoding_matrix.nonzero(as_tuple=True)[-1] + inside_encoding_matrix_idx = torch.reshape(inside_encoding_matrix_idx, (*kidx.shape[:-1], -1)) + self.register_buffer('_inside_encoding_matrix_idx', inside_encoding_matrix_idx) + kidx = torch.take_along_dim(kidx, inside_encoding_matrix_idx, dim=-1) + else: + self._inside_encoding_matrix_idx: torch.Tensor | None = None + self.register_buffer('_fft_idx', kidx) + # we can skip the indexing if the data is already sorted - self._needs_indexing = not torch.all(torch.diff(kidx) == 1) + self._needs_indexing = ( + not torch.all(torch.diff(kidx) == 1) + or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + or self._inside_encoding_matrix_idx is not None + ) self._trajectory_shape = traj.broadcasted_shape self._sorted_grid_shape = sorted_grid_shape @@ -91,8 +121,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: return (x,) x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)') - # take_along_dim does broadcast, so no need for extending here - x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1) + # take_along_dim broadcasts, but needs the same number of dimensions + idx = unsqueeze_left(self._fft_idx, x_kflat.ndim - self._fft_idx.ndim) + x_inside_encoding_matrix = torch.take_along_dim(x_kflat, idx, dim=-1) + + if self._inside_encoding_matrix_idx is None: + # all trajectory points are inside the encoding matrix + x_indexed = x_inside_encoding_matrix + else: + # we need to add zeros + x_indexed = self._broadcast_and_scatter_along_last_dim( + x_inside_encoding_matrix, + self._trajectory_shape[-1] * self._trajectory_shape[-2] * self._trajectory_shape[-3], + self._inside_encoding_matrix_idx, + ) + # reshape to (... other coil, k2, k1, k0) x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:]) @@ -118,18 +161,13 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)') - # scatter does not broadcast, so we need to manually broadcast the indices - broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1]) - idx_expanded = torch.broadcast_to(self._fft_idx, (*broadcast_shape, self._fft_idx.shape[-1])) + if self._inside_encoding_matrix_idx is not None: + idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim) + y_kflat = torch.take_along_dim(y_kflat, idx, dim=-1) - # although scatter_ is inplace, this will not cause issues with autograd, as self - # is always constant zero and gradients w.r.t. src work as expected. - y_scattered = torch.zeros( - *broadcast_shape, - self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, - dtype=y.dtype, - device=y.device, - ).scatter_(dim=-1, index=idx_expanded, src=y_kflat) + y_scattered = self._broadcast_and_scatter_along_last_dim( + y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx + ) # reshape to ..., other, coil, k2_enc, k1_enc, k0_enc y_reshaped = y_scattered.reshape( @@ -140,3 +178,37 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: ) return (y_reshaped,) + + @staticmethod + def _broadcast_and_scatter_along_last_dim( + data_to_scatter: torch.Tensor, n_last_dim: int, scatter_index: torch.Tensor + ) -> torch.Tensor: + """Broadcast scatter index and scatter into zero tensor. + + Parameters + ---------- + data_to_scatter + Data to be scattered at indices scatter_index + n_last_dim + Number of data points in last dimension + scatter_index + Indices describing where to scatter data + + Returns + ------- + Data scattered into tensor along scatter_index + """ + # scatter does not broadcast, so we need to manually broadcast the indices + broadcast_shape = torch.broadcast_shapes(scatter_index.shape[:-1], data_to_scatter.shape[:-1]) + idx_expanded = torch.broadcast_to(scatter_index, (*broadcast_shape, scatter_index.shape[-1])) + + # although scatter_ is inplace, this will not cause issues with autograd, as self + # is always constant zero and gradients w.r.t. src work as expected. + data_scattered = torch.zeros( + *broadcast_shape, + n_last_dim, + dtype=data_to_scatter.dtype, + device=data_to_scatter.device, + ).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter) + + return data_scattered diff --git a/src/mrpro/operators/FastFourierOp.py b/src/mrpro/operators/FastFourierOp.py index 53f3f6eb4..4ffe7e3a6 100644 --- a/src/mrpro/operators/FastFourierOp.py +++ b/src/mrpro/operators/FastFourierOp.py @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: FFT of x """ y = torch.fft.fftshift( - torch.fft.fftn(torch.fft.ifftshift(*self._pad_op.forward(x), dim=self._dim), dim=self._dim, norm='ortho'), + torch.fft.fftn(torch.fft.ifftshift(*self._pad_op(x), dim=self._dim), dim=self._dim, norm='ortho'), dim=self._dim, ) return (y,) diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index 79351da76..c57d1fbfd 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -11,6 +11,7 @@ from mrpro.data.enums import TrajType from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp from mrpro.operators.FastFourierOp import FastFourierOp from mrpro.operators.LinearOperator import LinearOperator @@ -67,12 +68,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): self._nufft_dims.append(dim) if self._fft_dims: - self._fast_fourier_op = FastFourierOp( + self._fast_fourier_op: FastFourierOp | None = FastFourierOp( dim=tuple(self._fft_dims), recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims), encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims), ) - + self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp( + encoding_matrix=encoding_matrix, traj=traj + ) + else: + self._fast_fourier_op = None + self._cart_sampling_op = None # Find dimensions which require NUFFT if self._nufft_dims: fft_dims_k210 = [ @@ -102,20 +108,23 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction - self._fwd_nufft_op = KbNufft( + self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - self._adj_nufft_op = KbNufftAdjoint( + self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - - self._kshape = traj.broadcasted_shape + else: + self._omega: torch.Tensor | None = None + self._fwd_nufft_op = None + self._adj_nufft_op = None + self._kshape = traj.broadcasted_shape @classmethod def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self: @@ -146,11 +155,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil k-space data with shape: (... coils k2 k1 k0) """ - if len(self._fft_dims): - # FFT - (x,) = self._fast_fourier_op.forward(x) - - if self._nufft_dims: + if self._fwd_nufft_op is not None and self._omega is not None: + # NUFFT Type 2 # we need to move the nufft-dimensions to the end and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils nufft_dims # we could move the permute to __init__ but then we still would need to prepend if len(other)>1 @@ -163,7 +169,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.flatten(end_dim=-len(keep_dims) - 1) # omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims) - # TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape). omega = self._omega.permute(*permute) omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :]) omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) @@ -173,6 +178,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims] x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils x = x.permute(*unpermute) + + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: + # FFT + (x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0]) + return (x,) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: @@ -187,11 +197,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil image data with shape: (... coils z y x) """ - if self._fft_dims: + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: # IFFT - (x,) = self._fast_fourier_op.adjoint(x) + (x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0]) - if self._nufft_dims: + if self._adj_nufft_op is not None and self._omega is not None: + # NUFFT Type 1 # we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils (nufft_dims) keep_dims = [-4, *self._nufft_dims] # -4 is coil diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index f136d9d51..d919e63c6 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -102,7 +102,7 @@ def operator_norm( max_iterations: int = 20, relative_tolerance: float = 1e-4, absolute_tolerance: float = 1e-5, - callback: Callable | None = None, + callback: Callable[[torch.Tensor], None] | None = None, ) -> torch.Tensor: """Power iteration for computing the operator norm of the linear operator. @@ -294,6 +294,20 @@ def __rmul__(self, other: torch.Tensor | complex) -> LinearOperator: else: return NotImplemented # type: ignore[unreachable] + def __and__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Vertical stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self], [other]] + return mrpro.operators.LinearOperatorMatrix(operators) + + def __or__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Horizontal stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self, other]] + return mrpro.operators.LinearOperatorMatrix(operators) + @property def gram(self) -> LinearOperator: """Gram operator. diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py new file mode 100644 index 000000000..ab0673398 --- /dev/null +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -0,0 +1,363 @@ +"""Linear Operator Matrix class.""" + +from __future__ import annotations + +import operator +from collections.abc import Callable, Iterator, Sequence +from functools import reduce +from types import EllipsisType +from typing import cast + +import torch +from typing_extensions import Self, Unpack + +from mrpro.operators.LinearOperator import LinearOperator, LinearOperatorSum +from mrpro.operators.Operator import Operator +from mrpro.operators.ZeroOp import ZeroOp + +_SingleIdxType = int | slice | EllipsisType | Sequence[int] +_IdxType = _SingleIdxType | tuple[_SingleIdxType, _SingleIdxType] + + +class LinearOperatorMatrix(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]]): + r"""Matrix of Linear Operators. + + A matrix of Linear Operators, where each element is a Linear Operator. + + This matrix can be applied to a sequence of tensors, where the number of tensors should match + the number of columns of the matrix. The output will be a sequence of tensors, where the number + of tensors will match the number of rows of the matrix. + The i-th output tensor is calculated as + :math:`\sum_j \text{operators}[i][j](x[j])` where :math:`\text{operators}[i][j]` is the linear operator + in the i-th row and j-th column and :math:`x[j]` is the j-th input tensor. + + The matrix can be indexed and sliced like a regular matrix to get submatrices. + If indexing returns a single element, it is returned as a Linear Operator. + + Basic arithmetic operations are supported with Linear Operators and Tensors. + + """ + + _operators: list[list[LinearOperator]] + + def __init__(self, operators: Sequence[Sequence[LinearOperator]]): + """Initialize Linear Operator Matrix. + + Create a matrix of LinearOperators from a sequence of rows, where each row is a sequence + of LinearOperators that represent the columns of the matrix. + + Parameters + ---------- + operators + A sequence of rows, which are sequences of Linear Operators. + """ + if not all(isinstance(op, LinearOperator) for row in operators for op in row): + raise ValueError('All elements should be LinearOperators.') + if not all(len(row) == len(operators[0]) for row in operators): + raise ValueError('All rows should have the same length.') + super().__init__() + self._operators = cast( # cast because ModuleList is not recognized as a list + list[list[LinearOperator]], torch.nn.ModuleList(torch.nn.ModuleList(row) for row in operators) + ) + self._shape = (len(operators), len(operators[0]) if operators else 0) + + @property + def shape(self) -> tuple[int, int]: + """Shape of the Operator Matrix (rows, columns).""" + return self._shape + + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has columns. + + Returns + ------- + Output tensors. The same number of tensors as the operator has rows. + """ + if len(x) != self.shape[1]: + raise ValueError('Input should be the same number of tensors as the LinearOperatorMatrix has columns.') + return tuple( + reduce(operator.add, (op(xi)[0] for op, xi in zip(row, x, strict=True))) for row in self._operators + ) + + def __getitem__(self, idx: _IdxType) -> Self | LinearOperator: + """Index the Operator Matrix. + + Parameters + ---------- + idx + Index or slice to select rows and columns. + + Returns + ------- + Subset LinearOperatorMatrix or Linear Operator. + """ + idxs: tuple[_SingleIdxType, _SingleIdxType] = idx if isinstance(idx, tuple) else (idx, slice(None)) + if len(idxs) > 2: + raise IndexError('Too many indices for LinearOperatorMatrix') + + def _to_numeric_index(idx: slice | int | Sequence[int] | EllipsisType, length: int) -> Sequence[int]: + """Convert index to a sequence of integers or raise an error.""" + if isinstance(idx, slice): + if (idx.start is not None and (idx.start < -length or idx.start >= length)) or ( + idx.stop is not None and (idx.stop < -length or idx.stop > length) + ): + raise IndexError('Index out of range') + return range(*idx.indices(length)) + if isinstance(idx, int): + if idx < -length or idx >= length: + raise IndexError('Index out of range') + return (idx,) + if idx is Ellipsis: + return range(length) + if isinstance(idx, Sequence): + if min(idx) < -length or max(idx) >= length: + raise IndexError('Index out of range') + return idx + else: + raise IndexError('Invalid index type') + + row_numbers = _to_numeric_index(idxs[0], self._shape[0]) + col_numbers = _to_numeric_index(idxs[1], self._shape[1]) + + sliced_operators = [ + [row[col_number] for col_number in col_numbers] + for row in [self._operators[row_number] for row_number in row_numbers] + ] + + # Return a single operator if only one row and column is selected + if len(row_numbers) == 1 and len(col_numbers) == 1: + return sliced_operators[0][0] + else: + return self.__class__(sliced_operators) + + def __iter__(self) -> Iterator[Sequence[LinearOperator]]: + """Iterate over the rows of the Operator Matrix.""" + return iter(self._operators) + + def __repr__(self): + """Representation of the Operator Matrix.""" + return f'LinearOperatorMatrix(shape={self._shape}, operators={self._operators})' + + # Note: The type ignores are needed because we currently cannot do arithmetic operations with non-linear operators. + def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: ignore[override] + """Addition.""" + operators: list[list[LinearOperator]] = [] + if isinstance(other, LinearOperatorMatrix): + if self.shape != other.shape: + raise ValueError('OperatorMatrix shapes do not match.') + for self_row, other_row in zip(self._operators, other._operators, strict=False): + operators.append([s + o for s, o in zip(self_row, other_row, strict=False)]) + elif isinstance(other, LinearOperator | torch.Tensor): + if not self.shape[0] == self.shape[1]: + raise NotImplementedError('Cannot add a LinearOperator to a non-square OperatorMatrix.') + for i, self_row in enumerate(self._operators): + operators.append([op + other if i == j else op for j, op in enumerate(self_row)]) + else: + return NotImplemented # type: ignore[unreachable] + return self.__class__(operators) + + def __radd__(self, other: Self | LinearOperator | torch.Tensor) -> Self: + """Right addition.""" + return self.__add__(other) + + def __mul__(self, other: torch.Tensor | Sequence[torch.Tensor | complex] | complex) -> Self: + """LinearOperatorMatrix*Tensor multiplication. + + Example: ([A,B]*c)(x) = [A*c, B*c](x) = A(c*x) + B(c*x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[1] + elif len(other) != self.shape[1]: + raise ValueError('Other should have the same length as the operator has columns.') + else: + other_ = other + operators = [] + for row in self._operators: + operators.append([op * o for op, o in zip(row, other_, strict=True)]) + return self.__class__(operators) + + def __rmul__(self, other: torch.Tensor | Sequence[torch.Tensor] | complex) -> Self: + """Tensor*LinearOperatorMatrix multiplication. + + Example: (c*[A,B])(x) = [c*A, c*B](x) = c*A(x) + c*B(x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[0] + elif len(other) != self.shape[0]: + raise ValueError('Other should have the same length as the operator has rows.') + else: + other_ = other + operators = [] + for row, o in zip(self._operators, other_, strict=True): + operators.append([cast(LinearOperator, o * op) for op in row]) + return self.__class__(operators) + + def __matmul__(self, other: LinearOperator | Self) -> Self: # type: ignore[override] + """Composition of operators.""" + if isinstance(other, LinearOperator): + return self._binary_operation(other, '__matmul__') + elif isinstance(other, LinearOperatorMatrix): + if self.shape[1] != other.shape[0]: + raise ValueError('OperatorMatrix shapes do not match.') + new_operators = [] + for row in self._operators: + new_row = [] + for other_col in zip(*other._operators, strict=True): + elements = [s @ o for s, o in zip(row, other_col, strict=True)] + new_row.append(LinearOperatorSum(*elements)) + new_operators.append(new_row) + return self.__class__(new_operators) + return NotImplemented # type: ignore[unreachable] + + @property + def H(self) -> Self: # noqa N802 + """Adjoints of the operators.""" + return self.__class__([[op.H for op in row] for row in zip(*self._operators, strict=True)]) + + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the adjoint of the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has rows. + + Returns + ------- + Output tensors. The same number of tensors as the operator has columns. + """ + return self.H(*x) + + @classmethod + def from_diagonal(cls, *operators: LinearOperator): + """Create a diagonal LinearOperatorMatrix. + + Create a square LinearOperatorMatrix with the given Linear Operators on the diagonal, + resulting in a block-diagonal linear operator. + + Parameters + ---------- + operators + Sequence of Linear Operators to be placed on the diagonal. + """ + operator_matrix: list[list[LinearOperator]] = [ + [op if i == j else ZeroOp(False) for j in range(len(operators))] for i, op in enumerate(operators) + ] + return cls(operator_matrix) + + def operator_norm( + self, + *initial_value: torch.Tensor, + dim: Sequence[int] | None = None, + max_iterations: int = 20, + relative_tolerance: float = 1e-4, + absolute_tolerance: float = 1e-5, + callback: Callable[[torch.Tensor], None] | None = None, + ) -> torch.Tensor: + """Upper bound of operator norm of the Matrix. + + Uses the bounds :math:`||[A, B]^T|||<=sqrt(||A||^2 + ||B||^2)` and :math:`||[A, B]|||<=max(||A||,||B||)` + to estimate the operator norm of the matrix. + First, operator_norm is called on each element of the matrix. + Next, the norm is estimated for each column using the first bound. + Finally, the norm of the full matrix of linear operators is calculated using the second bound. + + Parameters + ---------- + initial_value + Initial value(s) for the power iteration, length should match the number of columns + of the operator matrix. + dim + Dimensions to calculate the operator norm over. Other dimensions are assumed to be + batch dimensions. None means all dimensions. + max_iterations + Maximum number of iterations used in the power iteration. + relative_tolerance + Relative tolerance for convergence. + absolute_tolerance + Absolute tolerance for convergence. + callback + Callback function to be called with the current estimate of the operator norm. + + + Returns + ------- + Estimated operator norm upper bound. + """ + + def _singlenorm(op: LinearOperator, initial_value: torch.Tensor): + return op.operator_norm( + initial_value, + dim=dim, + max_iterations=max_iterations, + relative_tolerance=relative_tolerance, + absolute_tolerance=absolute_tolerance, + callback=callback, + ) + + if len(initial_value) != self.shape[1]: + raise ValueError('Initial value should have the same length as the operator has columns.') + norms = torch.tensor( + [[_singlenorm(op, iv) for op, iv in zip(row, initial_value, strict=True)] for row in self._operators] + ) + norm = norms.square().sum(-2).sqrt().amax(-1).unsqueeze(-1) + return norm + + def __or__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Vertical stacking.""" + if isinstance(other, LinearOperator): + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking : cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[*self._operators[0], other]] + return self.__class__(operators) + else: + if (rows_self := self.shape[0]) != (rows_other := other.shape[0]): + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack matrices with {rows_self} and {rows_other}.' + ) + operators = [[*self_row, *other_row] for self_row, other_row in zip(self, other, strict=True)] + return self.__class__(operators) + + def __ror__(self, other: LinearOperator) -> Self: + """Vertical stacking.""" + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[other, *self._operators[0]]] + return self.__class__(operators) + + def __and__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Horizontal stacking.""" + if isinstance(other, LinearOperator): + if cols := self.shape[1] > 1: + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack LinearOperator and matrix with {cols} columns.' + ) + operators = [*self._operators, [other]] + return self.__class__(operators) + else: + if (cols_self := self.shape[1]) != (cols_other := other.shape[1]): + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack matrices with {cols_self} and {cols_other} columns.' + ) + operators = [*self._operators, *other] + return self.__class__(operators) + + def __rand__(self, other: LinearOperator) -> Self: + """Horizontal stacking.""" + if cols := self.shape[1] > 1: + raise ValueError( + f'Shape mismatch in horizontal stacking: cannot stack LinearOperator and matrix with {cols} columns.' + ) + operators = [[other], *self._operators] + return self.__class__(operators) diff --git a/src/mrpro/operators/PCACompressionOp.py b/src/mrpro/operators/PCACompressionOp.py new file mode 100644 index 000000000..ace625ea8 --- /dev/null +++ b/src/mrpro/operators/PCACompressionOp.py @@ -0,0 +1,85 @@ +"""PCA Compression Operator.""" + +import einops +import torch +from einops import repeat + +from mrpro.operators.LinearOperator import LinearOperator + + +class PCACompressionOp(LinearOperator): + """PCA based compression operator.""" + + def __init__( + self, + data: torch.Tensor, + n_components: int, + ): + """Construct a PCA based compression operator. + + The operator carries out an SVD followed by a threshold of the n_components largest values along the last + dimension of a data with shape (*other, joint_dim, compression_dim). A single SVD is carried out for everything + along joint_dim. Other are batch dimensions. + + Consider combining this operator with :class:`mrpro.operators.RearrangeOp` to make sure the data is + in the correct shape before applying. + + Parameters + ---------- + data + Data of shape (*other, joint_dim, compression_dim) to be used to find the principal components. + n_components + Number of principal components to keep along the compression_dim. + """ + super().__init__() + # different compression matrices along the *other dimensions + data = data - data.mean(-1, keepdim=True) + correlation = einops.einsum(data, data.conj(), '... joint comp1, ... joint comp2 -> ... comp1 comp2') + _, _, v = torch.svd(correlation) + # add joint_dim along which the the compression is the same + v = repeat(v, '... comp1 comp2 -> ... joint_dim comp1 comp2', joint_dim=1) + self.register_buffer('_compression_matrix', v[..., :n_components, :].clone()) + + def forward(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the compression to the data. + + Parameters + ---------- + data + data to be compressed of shape (*other, joint_dim, compression_dim) + + Returns + ------- + compressed data of shape (*other, joint_dim, n_components) + """ + try: + result = (self._compression_matrix @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix {tuple(self._compression_matrix.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) + + def adjoint(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint compression to the data. + + Parameters + ---------- + data + compressed data of shape (*other, joint_dim, n_components) + + Returns + ------- + expanded data of shape (*other, joint_dim, compression_dim) + """ + try: + result = (self._compression_matrix.mH @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix^H {tuple(self._compression_matrix.mH.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 4fe58f1e3..c22f386cd 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -11,8 +11,10 @@ from mrpro.operators.FourierOp import FourierOp from mrpro.operators.GridSamplingOp import GridSamplingOp from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp +from mrpro.operators.PCACompressionOp import PCACompressionOp from mrpro.operators.PhaseOp import PhaseOp from mrpro.operators.ProximableFunctionalSeparableSum import ProximableFunctionalSeparableSum from mrpro.operators.SensitivityOp import SensitivityOp @@ -23,8 +25,6 @@ from mrpro.operators.ZeroOp import ZeroOp __all__ = [ - "functionals", - "models", "CartesianSamplingOp", "ConstraintsOp", "DensityCompensationOp", @@ -38,15 +38,22 @@ "GridSamplingOp", "IdentityOp", "LinearOperator", + "LinearOperatorMatrix", "MagnitudeOp", + "MultiIdentityOp", "Operator", + "PCACompressionOp", "PhaseOp", "ProximableFunctional", "ProximableFunctionalSeparableSum", + "ScaledFunctional", + "ScaledProximableFunctional", "SensitivityOp", "SignalModel", "SliceProjectionOp", "WaveletOp", "ZeroOp", "ZeroPadOp", -] + "functionals", + "models" +] \ No newline at end of file diff --git a/src/mrpro/operators/models/__init__.py b/src/mrpro/operators/models/__init__.py index 0ec3cc517..629a560d3 100644 --- a/src/mrpro/operators/models/__init__.py +++ b/src/mrpro/operators/models/__init__.py @@ -5,4 +5,12 @@ from mrpro.operators.models.WASABITI import WASABITI from mrpro.operators.models.MonoExponentialDecay import MonoExponentialDecay from mrpro.operators.models.TransientSteadyStateWithPreparation import TransientSteadyStateWithPreparation -__all__ = ["SaturationRecovery", "InversionRecovery", "MOLLI", "WASABI", "WASABITI", "MonoExponentialDecay", "TransientSteadyStateWithPreparation"] +__all__ = [ + "InversionRecovery", + "MOLLI", + "MonoExponentialDecay", + "SaturationRecovery", + "TransientSteadyStateWithPreparation", + "WASABI", + "WASABITI" +] \ No newline at end of file diff --git a/src/mrpro/phantoms/__init__.py b/src/mrpro/phantoms/__init__.py index 3c9046364..081d2e465 100644 --- a/src/mrpro/phantoms/__init__.py +++ b/src/mrpro/phantoms/__init__.py @@ -1,3 +1,3 @@ from mrpro.phantoms.EllipsePhantom import EllipsePhantom from mrpro.phantoms.phantom_elements import EllipseParameters -__all__ = ["EllipsePhantom", "EllipseParameters"] +__all__ = ["EllipseParameters", "EllipsePhantom"] \ No newline at end of file diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index ce3daf2c8..80ef9d398 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,9 +1,23 @@ import mrpro.utils.slice_profiles import mrpro.utils.typing +import mrpro.utils.unit_conversion from mrpro.utils.smap import smap from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop -from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx -from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right -__all__ = ["slice_profiles", "typing", "smap", "remove_repeat", "zero_pad_or_crop", "modify_acq_info", "split_idx", "broadcast_right", "unsqueeze_left", "unsqueeze_right"] +from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view +import mrpro.utils.unit_conversion + +__all__ = [ + "broadcast_right", + "reduce_view", + "remove_repeat", + "slice_profiles", + "smap", + "split_idx", + "typing", + "unit_conversion", + "unsqueeze_left", + "unsqueeze_right", + "zero_pad_or_crop" +] diff --git a/src/mrpro/utils/modify_acq_info.py b/src/mrpro/utils/modify_acq_info.py deleted file mode 100644 index d535e53c5..000000000 --- a/src/mrpro/utils/modify_acq_info.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Modify AcqInfo.""" - -from __future__ import annotations - -import dataclasses -from collections.abc import Callable -from typing import TYPE_CHECKING - -import torch - -if TYPE_CHECKING: - from mrpro.data.AcqInfo import AcqInfo - - -def modify_acq_info(fun_modify: Callable, acq_info: AcqInfo) -> AcqInfo: - """Go through all fields of AcqInfo object and apply changes. - - Parameters - ---------- - fun_modify - Function which takes AcqInfo fields as input and returns modified AcqInfo field - acq_info - AcqInfo object - """ - # Apply function to all fields of acq_info - for field in dataclasses.fields(acq_info): - current = getattr(acq_info, field.name) - if isinstance(current, torch.Tensor): - setattr(acq_info, field.name, fun_modify(current)) - elif dataclasses.is_dataclass(current): - for subfield in dataclasses.fields(current): - subcurrent = getattr(current, subfield.name) - setattr(current, subfield.name, fun_modify(subcurrent)) - - return acq_info diff --git a/src/mrpro/utils/reshape.py b/src/mrpro/utils/reshape.py index 39d12e51f..31d495afd 100644 --- a/src/mrpro/utils/reshape.py +++ b/src/mrpro/utils/reshape.py @@ -1,5 +1,7 @@ """Tensor reshaping utilities.""" +from collections.abc import Sequence + import torch @@ -67,3 +69,33 @@ def broadcast_right(*x: torch.Tensor) -> tuple[torch.Tensor, ...]: max_dim = max(el.ndim for el in x) unsqueezed = torch.broadcast_tensors(*(unsqueeze_right(el, max_dim - el.ndim) for el in x)) return unsqueezed + + +def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor: + """Reduce expanded dimensions in a view to singletons. + + Reduce either all or specific dimensions to a singleton if it + points to the same memory address. + This undoes expand. + + Parameters + ---------- + x + input tensor + dim + only reduce expanded dimensions in the specified dimensions. + If None, reduce all expanded dimensions. + """ + if dim is None: + dim_: Sequence[int] = range(x.ndim) + elif isinstance(dim, Sequence): + dim_ = [d % x.ndim for d in dim] + else: + dim_ = [dim % x.ndim] + + stride = x.stride() + newsize = [ + 1 if stride == 0 and d in dim_ else oldsize + for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True)) + ] + return torch.as_strided(x, newsize, stride) diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py new file mode 100644 index 000000000..0115bed47 --- /dev/null +++ b/src/mrpro/utils/unit_conversion.py @@ -0,0 +1,94 @@ +"""Conversion between different units.""" + +from typing import TypeVar + +import numpy as np +import torch + +__all__ = [ + 'ms_to_s', + 's_to_ms', + 'mm_to_m', + 'm_to_mm', + 'deg_to_rad', + 'rad_to_deg', + 'lamor_frequency_to_magnetic_field', + 'magnetic_field_to_lamor_frequency', + 'GYROMAGNETIC_RATIO_PROTON', +] + +GYROMAGNETIC_RATIO_PROTON = 42.58 * 1e6 +r"""The gyromagnetic ratio :math:`\frac{\gamma}{2\pi}` of 1H in H20 in Hz/T""" + +# Conversion functions for units +T = TypeVar('T', float, torch.Tensor) + + +def ms_to_s(ms: T) -> T: + """Convert ms to s.""" + return ms / 1000 + + +def s_to_ms(s: T) -> T: + """Convert s to ms.""" + return s * 1000 + + +def mm_to_m(mm: T) -> T: + """Convert mm to m.""" + return mm / 1000 + + +def m_to_mm(m: T) -> T: + """Convert m to mm.""" + return m * 1000 + + +def deg_to_rad(deg: T) -> T: + """Convert degree to radians.""" + if isinstance(deg, torch.Tensor): + return torch.deg2rad(deg) + return deg / 180.0 * np.pi + + +def rad_to_deg(deg: T) -> T: + """Convert radians to degree.""" + if isinstance(deg, torch.Tensor): + return torch.rad2deg(deg) + return deg * 180.0 / np.pi + + +def lamor_frequency_to_magnetic_field(lamor_frequency: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON) -> T: + """Convert the Lamor frequency [Hz] to the magntic field strength [T]. + + Parameters + ---------- + lamor_frequency + Lamor frequency [Hz] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Magnetic field strength [T] + """ + return lamor_frequency / gyromagnetic_ratio + + +def magnetic_field_to_lamor_frequency( + magnetic_field_strength: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON +) -> T: + """Convert the magntic field strength [T] to Lamor frequency [Hz]. + + Parameters + ---------- + magnetic_field_strength + Strength of the magnetic field [T] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Lamor frequency [Hz] + """ + return magnetic_field_strength * gyromagnetic_ratio diff --git a/tests/algorithms/test_cg.py b/tests/algorithms/test_cg.py index 16c16e78b..4abda2a98 100644 --- a/tests/algorithms/test_cg.py +++ b/tests/algorithms/test_cg.py @@ -16,7 +16,8 @@ (1, 32, False), (4, 32, True), (4, 32, False), - ] + ], + ids=['complex_single', 'real_single', 'complex_batch', 'real_batch'], ) def system(request): """Generate data for creating a system Hx=b with linear and self-adjoint @@ -145,3 +146,13 @@ def callback(cg_status: CGStatus) -> None: assert True cg(h_operator, right_hand_side, callback=callback) + + +def test_autograd(system): + """Test autograd through cg""" + h_operator, right_hand_side, _ = system + right_hand_side.requires_grad_(True) + with torch.autograd.detect_anomaly(): + result = cg(h_operator, right_hand_side, tolerance=0, max_iterations=5) + result.abs().sum().backward() + assert right_hand_side.grad is not None diff --git a/tests/conftest.py b/tests/conftest.py index e3f943462..30ae9c229 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from xsdata.models.datatype import XmlDate, XmlTime from tests import RandomGenerator +from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData @@ -45,9 +46,9 @@ def generate_random_acquisition_properties(generator: RandomGenerator): 'encoding_space_ref': generator.uint16(), 'sample_time_us': generator.float32(), 'position': generator.float32_tuple(3), - 'read_dir': generator.float32_tuple(3), - 'phase_dir': generator.float32_tuple(3), - 'slice_dir': generator.float32_tuple(3), + 'read_dir': (1, 0, 0), # read, phase and slice have to form rotation + 'phase_dir': (0, 1, 0), + 'slice_dir': (0, 0, 1), 'patient_table_position': generator.float32_tuple(3), 'idx': ismrmrd.EncodingCounters(**idx_properties), 'user_int': generator.uint32_tuple(8), @@ -233,178 +234,193 @@ def create_uniform_traj(nk, k_shape): return k -def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): +def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Create trajectory with random entries.""" random_generator = RandomGenerator(seed=0) k_list = [] - for spacing, nk in zip([sz, sy, sx], [nkz, nky, nkx], strict=True): - if spacing == 'nuf': - k = random_generator.float32_tensor(size=nk) - elif spacing == 'uf': + for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True): + if spacing == 'non-uniform': + k = random_generator.float32_tensor(size=nk, low=-1, high=1) * max(nk) + elif spacing == 'uniform': k = create_uniform_traj(nk, k_shape=k_shape) - elif spacing == 'z': + elif spacing == 'zero': k = torch.zeros(nk) k_list.append(k) trajectory = KTrajectory(k_list[0], k_list[1], k_list[2], repeat_detection_tolerance=None) return trajectory +@pytest.fixture(scope='session') +def ismrmrd_cart(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + ) + return ismrmrd_kdata + + COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz', 's0', 's1', 's2'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ - # (0) 2d cart mri with 1 coil, no oversampling - ( - (1, 1, 1, 96, 128), # img shape - (1, 1, 1, 96, 128), # k shape - (1, 1, 1, 128), # kx - (1, 1, 96, 1), # ky - (1, 1, 1, 1), # kz - 'uf', # kx is uniform - 'uf', # ky is uniform - 'z', # zero so no Fourier transform is performed along that dimension - 'uf', # k0 is uniform - 'uf', # k1 is uniform - 'z', # k2 is singleton + ( # (0) 2d Cartesian single coil, no oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 96, 128), # k_shape + (1, 1, 1, 128), # nkx + (1, 1, 96, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (1) 2d cart mri with 1 coil, with oversampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 128, 192), - (1, 1, 1, 192), - (1, 1, 128, 1), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (1) 2d Cartesian single coil, with oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 128, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 128, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (2) 2d non-Cartesian mri with 2 coils - ( - (1, 2, 1, 96, 128), - (1, 2, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 1, 1), - 'nuf', # kx is non-uniform - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (2) 2d non-Cartesian mri with 2 coils + (1, 2, 1, 96, 128), # im_shape + (1, 2, 1, 16, 192), # k_shape + (1, 1, 16, 192), # nkx + (1, 1, 16, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (3) 2d cart mri with irregular sampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'z', - 'z', + ( # (3) 2d Cartesian with irregular sampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (4) 2d single shot spiral - ( - (1, 2, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'z', - 'z', + ( # (4) 2d single shot spiral + (1, 2, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (5) 3d nuFFT mri, 4 coils, 2 other - ( - (2, 4, 16, 32, 64), - (2, 4, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (5) 3d non-uniform, 4 coils, 2 other + (2, 4, 16, 32, 64), # im_shape + (2, 4, 16, 32, 64), # k_shape + (2, 16, 32, 64), # nkx + (2, 16, 32, 64), # nky + (2, 16, 32, 64), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (6) 2d nuFFT cine mri with 8 cardiac phases, 5 coils - ( - (8, 5, 1, 64, 64), - (8, 5, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (6) 2d non-uniform cine with 8 cardiac phases, 5 coils + (8, 5, 1, 64, 64), # im_shape + (8, 5, 1, 18, 128), # k_shape + (8, 1, 18, 128), # nkx + (8, 1, 18, 128), # nky + (8, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (7) 2d cart cine mri with 9 cardiac phases, 6 coils - ( - (9, 6, 1, 96, 128), - (9, 6, 1, 128, 192), - (9, 1, 1, 192), - (9, 1, 128, 1), - (9, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (7) 2d cartesian cine with 9 cardiac phases, 6 coils + (9, 6, 1, 96, 128), # im_shape + (9, 6, 1, 128, 192), # k_shape + (9, 1, 1, 192), # nkx + (9, 1, 128, 1), # nky + (9, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and nuFFT directions - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'uf', - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', + ( # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and non-uniform directions + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (9) radial phase encoding (RPE) , 8 coils with non-Cartesian sampling along readout - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (9) radial phase encoding (RPE), 8 coils with non-Cartesian sampling along readout + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and nuFFT directions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', - 'uf', + ( # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and non-uniform directions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 1, 18, 64), # nky + (5, 96, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'uniform', # type_k2 ), ], + ids=[ + '2d_cartesian_1_coil_no_oversampling', + '2d_cartesian_1_coil_with_oversampling', + '2d_non_cartesian_mri_2_coils', + '2d_cartesian_irregular_sampling', + '2d_single_shot_spiral', + '3d_nonuniform_4_coils_2_other', + '2d_nnonuniform_cine_mri_8_cardiac_phases_5_coils', + '2d_cartesian_cine_9_cardiac_phases_6_coils', + 'radial_phase_encoding_8_coils_with_oversampling', + 'radial_phase_encoding_8_coils_non_cartesian_sampling', + 'stack_of_stars_5_other_3_coil_with_oversampling', + ], ) diff --git a/tests/data/_IsmrmrdRawTestData.py b/tests/data/_IsmrmrdRawTestData.py index 59c42be15..181be56d6 100644 --- a/tests/data/_IsmrmrdRawTestData.py +++ b/tests/data/_IsmrmrdRawTestData.py @@ -67,6 +67,7 @@ def __init__( trajectory_type: Literal['cartesian', 'radial'] = 'cartesian', sampling_order: Literal['linear', 'low_high', 'high_low', 'random'] = 'linear', phantom: EllipsePhantom | None = None, + add_bodycoil_acquisitions: bool = False, n_separate_calibration_lines: int = 0, ): if not phantom: @@ -222,23 +223,32 @@ def __init__( acq.phase_dir[1] = 1.0 acq.slice_dir[2] = 1.0 - # Initialize an acquisition counter - counter = 0 + scan_counter = 0 # Write out a few noise scans for _ in range(32): noise = self.noise_level * torch.randn(self.n_coils, n_freq_encoding, dtype=torch.complex64) # here's where we would make the noise correlated - acq.scan_counter = counter + acq.scan_counter = scan_counter acq.clearAllFlags() acq.setFlag(ismrmrd.ACQ_IS_NOISE_MEASUREMENT) acq.data[:] = noise.numpy() dataset.append_acquisition(acq) - counter += 1 # increment the scan counter + scan_counter += 1 + + # Add acquisitions obtained with a 2-element body coil (e.g. used for adjustment scans) + if add_bodycoil_acquisitions: + acq.resize(n_freq_encoding, 2, trajectory_dimensions=2) + for _ in range(8): + acq.scan_counter = scan_counter + acq.clearAllFlags() + acq.data[:] = torch.randn(2, n_freq_encoding, dtype=torch.complex64) + dataset.append_acquisition(acq) + scan_counter += 1 + acq.resize(n_freq_encoding, self.n_coils, trajectory_dimensions=2) # Calibration lines if n_separate_calibration_lines > 0: - # we take calibration lines around the k-space center traj_ky_calibration, traj_kx_calibration, kpe_calibration = self._cartesian_trajectory( n_separate_calibration_lines, n_freq_encoding, @@ -253,7 +263,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe_calibration): # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -264,7 +274,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_calibration[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Loop over the repetitions, add noise and write to disk for rep in range(self.repetitions): @@ -275,7 +285,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe[rep]): if not self.flag_invalid_reps or rep == 0 or pe_idx < len(kpe[rep]) // 2: # fewer lines for rep > 0 # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -298,7 +308,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_with_noise[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Clean up dataset.close() diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index 822c63045..d5cfa0f0c 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -2,12 +2,13 @@ import pytest import torch -from einops import rearrange, repeat +from einops import repeat from mrpro.data import KData, KTrajectory, SpatialDimension -from mrpro.data.acq_filters import is_coil_calibration_acquisition +from mrpro.data.acq_filters import has_n_coils, is_coil_calibration_acquisition, is_image_acquisition +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp -from mrpro.utils import modify_acq_info, split_idx +from mrpro.utils import split_idx from tests.conftest import RandomGenerator, generate_random_data from tests.data import IsmrmrdRawTestData @@ -16,14 +17,15 @@ @pytest.fixture(scope='session') -def ismrmrd_cart(ellipse_phantom, tmp_path_factory): - """Fully sampled cartesian data set.""" +def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set with bodycoil and surface coil data.""" ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' ismrmrd_kdata = IsmrmrdRawTestData( filename=ismrmrd_filename, noise_level=0.0, repetitions=3, phantom=ellipse_phantom.phantom, + add_bodycoil_acquisitions=True, ) return ismrmrd_kdata @@ -77,10 +79,11 @@ def consistently_shaped_kdata(request, random_kheader_shape): # Start with header kheader, n_other, n_coils, n_k2, n_k1, n_k0 = random_kheader_shape - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) # Create kdata with consistent shape kdata = generate_random_data(RandomGenerator(request.param['seed']), (n_other, n_coils, n_k2, n_k1, n_k0)) @@ -124,6 +127,34 @@ def test_KData_raise_wrong_trajectory_shape(ismrmrd_cart): _ = KData.from_file(ismrmrd_cart.filename, trajectory) +def test_KData_raise_warning_for_bodycoil(ismrmrd_cart_bodycoil_and_surface_coil): + """Mix of bodycoil and surface coil acquisitions leads to warning.""" + with pytest.raises(UserWarning, match='Acquisitions with different number'): + _ = KData.from_file(ismrmrd_cart_bodycoil_and_surface_coil.filename, DummyTrajectory()) + + +@pytest.mark.filterwarnings('ignore:Acquisitions with different number:UserWarning') +def test_KData_select_bodycoil_via_filter(ismrmrd_cart_bodycoil_and_surface_coil): + """Bodycoil can be selected via a custom acquisition filter.""" + # This is the recommended way of selecting the body coil (i.e. 2 receiver elements) + kdata = KData.from_file( + ismrmrd_cart_bodycoil_and_surface_coil.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + assert kdata.data.shape[-4] == 2 + + +def test_KData_raise_wrong_coil_number(ismrmrd_cart): + """Wrong number of coils leads to empty acquisitions.""" + with pytest.raises(ValueError, match='No acquisitions meeting the given filter criteria were found'): + _ = KData.from_file( + ismrmrd_cart.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + + def test_KData_from_file_diff_nky_for_rep(ismrmrd_cart_invalid_reps): """Multiple repetitions with different number of phase encoding lines.""" with pytest.warns(UserWarning, match=r'different number'): @@ -162,7 +193,7 @@ def test_KData_kspace(ismrmrd_cart): assert relative_image_difference(reconstructed_img[0, 0, 0, ...], ismrmrd_cart.img_ref) <= 0.05 -@pytest.mark.parametrize(('field', 'value'), [('b0', 11.3), ('tr', torch.tensor([24.3]))]) +@pytest.mark.parametrize(('field', 'value'), [('lamor_frequency_proton', 42.88 * 1e6), ('tr', torch.tensor([24.3]))]) def test_KData_modify_header(ismrmrd_cart, field, value): """Overwrite some parameters in the header.""" parameter_dict = {field: value} @@ -469,3 +500,21 @@ def test_KData_remove_readout_os(monkeypatch, random_kheader): # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high # differences along the edges of the elliptic objects. assert relative_image_difference(torch.abs(img_recon), img_tensor[:, 0, ...]) <= 0.05 + + +def test_modify_acq_info(random_kheader_shape): + """Test the modification of the acquisition info.""" + # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] + kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape + + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) + + # Verify shape + assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.idx.k1.shape == (n_other, n_k2, n_k1) + assert kheader.acq_info.orientation.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.position.z.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 06f55a4dc..3feb091de 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -23,6 +23,7 @@ class A(MoveDataMixin): """Test class A.""" floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0)) + floattensor2: torch.Tensor = field(default_factory=lambda: torch.tensor(-1.0)) complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) @@ -204,3 +205,21 @@ def testchild(attribute, expected_dtype): assert original is not new, 'original and new should not be the same object' assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' + + +def test_movedatamixin_apply(): + """Tests apply_ method of MoveDataMixin.""" + data = B() + # make one of the parameters shared to test memo behavior + data.child.floattensor2 = data.child.floattensor + original = data.clone() + + def multiply_by_2(obj): + if isinstance(obj, torch.Tensor): + return obj * 2 + return obj + + data.apply_(multiply_by_2) + torch.testing.assert_close(data.floattensor, original.floattensor * 2) + torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2) + assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' diff --git a/tests/data/test_rotation.py b/tests/data/test_rotation.py index 6b4cbb52c..035a2d6c6 100644 --- a/tests/data/test_rotation.py +++ b/tests/data/test_rotation.py @@ -535,7 +535,7 @@ def _test_stats(error: torch.Tensor, mean_max: float, rms_max: float) -> None: assert torch.all(rms < rms_max) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -555,7 +555,7 @@ def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-15, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_symmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -576,7 +576,7 @@ def test_as_euler_symmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-16, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix @@ -598,7 +598,7 @@ def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): torch.testing.assert_close(mat_expected, mat_estimated) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_symmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix diff --git a/tests/data/test_spatial_dimension.py b/tests/data/test_spatial_dimension.py index cd46854b9..afafece04 100644 --- a/tests/data/test_spatial_dimension.py +++ b/tests/data/test_spatial_dimension.py @@ -115,29 +115,6 @@ def conversion(x: torch.Tensor) -> torch.Tensor: assert torch.equal(spatial_dimension_inplace.z, z) -def test_spatial_dimension_apply(): - """Test apply (out of place)""" - - def conversion(x: torch.Tensor) -> torch.Tensor: - assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' - return x.swapaxes(0, 1).square() - - xyz = RandomGenerator(0).float32_tensor((1, 2, 3)) - spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy()) - spatial_dimension_outofplace = spatial_dimension.apply().apply(conversion) - - assert spatial_dimension_outofplace is not spatial_dimension - - assert isinstance(spatial_dimension_outofplace.x, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.y, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.z, torch.Tensor) - - x, y, z = conversion(xyz).unbind(-1) - assert torch.equal(spatial_dimension_outofplace.x, x) - assert torch.equal(spatial_dimension_outofplace.y, y) - assert torch.equal(spatial_dimension_outofplace.z, z) - - def test_spatial_dimension_zyx(): """Test the zyx tuple property""" z, y, x = (2, 3, 4) diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index 1baf4340b..1061a93be 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -147,16 +147,16 @@ def test_trajectory_cpu(cartesian_grid): @COMMON_MR_TRAJECTORIES -def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the kz, ky and kz dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_kzyx[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) @@ -171,16 +171,16 @@ def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, @COMMON_MR_TRAJECTORIES -def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the k2, k1 and k0 dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_k210[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) diff --git a/tests/operators/functionals/conftest.py b/tests/operators/functionals/conftest.py index 9d2bfd8c9..ae8ffc20b 100644 --- a/tests/operators/functionals/conftest.py +++ b/tests/operators/functionals/conftest.py @@ -47,12 +47,12 @@ def result_dtype(self): def functional_test_cases(func: Callable[[FunctionalTestCase], None]) -> Callable[..., None]: """Decorator combining multiple parameterizations for test cases for all proximable functionals.""" - @pytest.mark.parametrize('shape', [[1, 2, 3]]) + @pytest.mark.parametrize('shape', [[1, 2, 3]], ids=['shape=[1,2,3]']) @pytest.mark.parametrize('dtype_name', ['float32', 'complex64']) @pytest.mark.parametrize('weight', ['scalar_weight', 'tensor_weight', 'complex_weight']) @pytest.mark.parametrize('target', ['no_target', 'random_target']) - @pytest.mark.parametrize('dim', [None]) - @pytest.mark.parametrize('divide_by_n', [True, False]) + @pytest.mark.parametrize('dim', [None], ids=['dim=None']) + @pytest.mark.parametrize('divide_by_n', [True, False], ids=['mean', 'sum']) @pytest.mark.parametrize('functional', PROXIMABLE_FUNCTIONALS) def wrapper( functional: type[ElementaryProximableFunctional], diff --git a/tests/operators/models/conftest.py b/tests/operators/models/conftest.py index 570fa2f1f..75fceacd2 100644 --- a/tests/operators/models/conftest.py +++ b/tests/operators/models/conftest.py @@ -8,7 +8,7 @@ SHAPE_VARIATIONS_SIGNAL_MODELS = pytest.mark.parametrize( ('parameter_shape', 'contrast_dim_shape', 'signal_shape'), [ - ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different inversion times + ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different contrast times ((1, 1, 10, 20, 30), (5, 1), (5, 1, 1, 10, 20, 30)), ((4, 1, 1, 10, 20, 30), (5, 1), (5, 4, 1, 1, 10, 20, 30)), # multiple maps along additional batch dimension ((4, 1, 1, 10, 20, 30), (5,), (5, 4, 1, 1, 10, 20, 30)), @@ -25,6 +25,24 @@ ((1,), (5,), (5, 1)), # single voxel ((4, 3, 1), (5, 4, 3), (5, 4, 3, 1)), ], + ids=[ + 'single_map_diff_contrast_times', + 'single_map_diff_contrast_times_2', + 'multiple_maps_additional_batch_dim', + 'multiple_maps_additional_batch_dim_2', + 'multiple_maps_additional_batch_dim_3', + 'multiple_maps_other_dim', + 'multiple_maps_other_dim_2', + 'multiple_maps_other_dim_3', + 'multiple_maps_other_and_batch_dim', + 'multiple_maps_other_and_batch_dim_2', + 'multiple_maps_other_and_batch_dim_3', + 'multiple_maps_other_and_batch_dim_4', + 'multiple_maps_other_and_batch_dim_5', + 'different_value_each_voxel', + 'single_voxel', + 'multiple_voxels', + ], ) diff --git a/tests/operators/models/test_inversion_recovery.py b/tests/operators/models/test_inversion_recovery.py index 7c36489bc..793ed2bf7 100644 --- a/tests/operators/models/test_inversion_recovery.py +++ b/tests/operators/models/test_inversion_recovery.py @@ -22,7 +22,7 @@ def test_inversion_recovery(ti, result): """ model = InversionRecovery(ti) m0, t1 = create_parameter_tensor_tuples() - (image,) = model.forward(m0, t1) + (image,) = model(m0, t1) # Assert closeness to -m0 for ti=0 if result == '-m0': @@ -38,7 +38,7 @@ def test_inversion_recovery_shape(parameter_shape, contrast_dim_shape, signal_sh (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = InversionRecovery(ti) m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, t1) + (signal,) = model_op(m0, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_molli.py b/tests/operators/models/test_molli.py index c459a9eb6..61b2b3d8f 100644 --- a/tests/operators/models/test_molli.py +++ b/tests/operators/models/test_molli.py @@ -28,7 +28,7 @@ def test_molli(ti, result): # Generate signal model and torch tensor for comparison model = MOLLI(ti) - (image,) = model.forward(a, c, t1) + (image,) = model(a, c, t1) # Assert closeness to a(1-c) for large ti if result == 'a(1-c)': @@ -44,7 +44,7 @@ def test_molli_shape(parameter_shape, contrast_dim_shape, signal_shape): (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = MOLLI(ti) a, c, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(a, c, t1) + (signal,) = model_op(a, c, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_mono_exponential_decay.py b/tests/operators/models/test_mono_exponential_decay.py index 5f81b7d0c..0072d6660 100644 --- a/tests/operators/models/test_mono_exponential_decay.py +++ b/tests/operators/models/test_mono_exponential_decay.py @@ -22,7 +22,7 @@ def test_mono_exponential_decay(decay_time, result): """ model = MonoExponentialDecay(decay_time) m0, decay_constant = create_parameter_tensor_tuples() - (image,) = model.forward(m0, decay_constant) + (image,) = model(m0, decay_constant) zeros = torch.zeros_like(m0) @@ -40,7 +40,7 @@ def test_mono_exponential_decay_shape(parameter_shape, contrast_dim_shape, signa (decay_time,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = MonoExponentialDecay(decay_time) m0, decay_constant = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, decay_constant) + (signal,) = model_op(m0, decay_constant) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_saturation_recovery.py b/tests/operators/models/test_saturation_recovery.py index 6ab233fb1..6312bd33a 100644 --- a/tests/operators/models/test_saturation_recovery.py +++ b/tests/operators/models/test_saturation_recovery.py @@ -22,7 +22,7 @@ def test_saturation_recovery(ti, result): """ model = SaturationRecovery(ti) m0, t1 = create_parameter_tensor_tuples() - (image,) = model.forward(m0, t1) + (image,) = model(m0, t1) zeros = torch.zeros_like(m0) @@ -40,7 +40,7 @@ def test_saturation_recovery_shape(parameter_shape, contrast_dim_shape, signal_s (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = SaturationRecovery(ti) m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, t1) + (signal,) = model_op(m0, t1) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_transient_steady_state_with_preparation.py b/tests/operators/models/test_transient_steady_state_with_preparation.py index ef3563516..b1bd71c4c 100644 --- a/tests/operators/models/test_transient_steady_state_with_preparation.py +++ b/tests/operators/models/test_transient_steady_state_with_preparation.py @@ -24,7 +24,7 @@ def test_transient_steady_state(sampling_time, m0_scaling_preparation, result): repetition_time = 5 model = TransientSteadyStateWithPreparation(sampling_time, repetition_time, m0_scaling_preparation) m0, t1, flip_angle = create_parameter_tensor_tuples(number_of_tensors=3) - (signal,) = model.forward(m0, t1, flip_angle) + (signal,) = model(m0, t1, flip_angle) # Assert closeness to m0 if result == 'm0': @@ -57,7 +57,7 @@ def test_transient_steady_state_inversion_recovery(): analytical_signal = m0 * (1 - 2 * torch.exp(-(sampling_time / t1))) model = TransientSteadyStateWithPreparation(sampling_time, repetition_time=100, m0_scaling_preparation=-1) - (signal,) = model.forward(m0, t1, flip_angle) + (signal,) = model(m0, t1, flip_angle) torch.testing.assert_close(signal, analytical_signal) @@ -78,7 +78,7 @@ def test_transient_steady_state_shape(parameter_shape, contrast_dim_shape, signa sampling_time, repetition_time, m0_scaling_preparation, delay_after_preparation ) m0, t1, flip_angle = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(m0, t1, flip_angle) + (signal,) = model_op(m0, t1, flip_angle) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_wasabi.py b/tests/operators/models/test_wasabi.py index 97a448c3e..dda9268a1 100644 --- a/tests/operators/models/test_wasabi.py +++ b/tests/operators/models/test_wasabi.py @@ -15,11 +15,11 @@ def test_WASABI_shift(): """Test symmetry property of shifted and unshifted WASABI spectra.""" offsets_unshifted, b0_shift, rb1, c, d = create_data() wasabi_model = WASABI(offsets=offsets_unshifted) - (signal,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal,) = wasabi_model(b0_shift, rb1, c, d) offsets_shifted, b0_shift, rb1, c, d = create_data(b0_shift=100) wasabi_model = WASABI(offsets=offsets_shifted) - (signal_shifted,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal_shifted,) = wasabi_model(b0_shift, rb1, c, d) lower_index = (offsets_shifted == -300).nonzero()[0][0].item() upper_index = (offsets_shifted == 500).nonzero()[0][0].item() @@ -31,7 +31,7 @@ def test_WASABI_shift(): def test_WASABI_extreme_offset(): offset, b0_shift, rb1, c, d = create_data(offset_max=30000, n_offsets=1) wasabi_model = WASABI(offsets=offset) - (signal,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal,) = wasabi_model(b0_shift, rb1, c, d) assert torch.isclose(signal, torch.tensor([1.0])), 'For an extreme offset, the signal should be unattenuated' @@ -42,7 +42,7 @@ def test_WASABI_shape(parameter_shape, contrast_dim_shape, signal_shape): (offsets,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = WASABI(offsets) b0_shift, rb1, c, d = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=4) - (signal,) = model_op.forward(b0_shift, rb1, c, d) + (signal,) = model_op(b0_shift, rb1, c, d) assert signal.shape == signal_shape diff --git a/tests/operators/models/test_wasabiti.py b/tests/operators/models/test_wasabiti.py index ad536b51d..4201a35a0 100644 --- a/tests/operators/models/test_wasabiti.py +++ b/tests/operators/models/test_wasabiti.py @@ -16,7 +16,7 @@ def test_WASABITI_symmetry(): """Test symmetry property of complete WASABITI spectra.""" offsets, b0_shift, rb1, t1 = create_data() wasabiti_model = WASABITI(offsets=offsets, trec=torch.ones_like(offsets)) - (signal,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal,) = wasabiti_model(b0_shift, rb1, t1) # check that all values are symmetric around the center assert torch.allclose(signal, signal.flipud(), rtol=1e-15), 'Result should be symmetric around center' @@ -27,7 +27,7 @@ def test_WASABITI_symmetry_after_shift(): offsets_shifted, b0_shift, rb1, t1 = create_data(b0_shift=100) trec = torch.ones_like(offsets_shifted) wasabiti_model = WASABITI(offsets=offsets_shifted, trec=trec) - (signal_shifted,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal_shifted,) = wasabiti_model(b0_shift, rb1, t1) lower_index = (offsets_shifted == -300).nonzero()[0][0].item() upper_index = (offsets_shifted == 500).nonzero()[0][0].item() @@ -43,7 +43,7 @@ def test_WASABITI_asymmetry_for_non_unique_trec(): trec[: len(offsets_unshifted) // 2] = 2.0 wasabiti_model = WASABITI(offsets=offsets_unshifted, trec=trec) - (signal,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal,) = wasabiti_model(b0_shift, rb1, t1) assert not torch.allclose(signal, signal.flipud(), rtol=1e-8), 'Result should not be symmetric around center' @@ -54,7 +54,7 @@ def test_WASABITI_relaxation_term(t1): offset, b0_shift, rb1, t1 = create_data(offset_max=50000, n_offsets=1, t1=t1) trec = torch.ones_like(offset) * t1 wasabiti_model = WASABITI(offsets=offset, trec=trec) - sig = wasabiti_model.forward(b0_shift, rb1, t1) + sig = wasabiti_model(b0_shift, rb1, t1) assert torch.isclose(sig[0], torch.FloatTensor([1 - torch.exp(torch.FloatTensor([-1]))]), rtol=1e-8) @@ -73,7 +73,7 @@ def test_WASABITI_shape(parameter_shape, contrast_dim_shape, signal_shape): ti, trec = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=2) model_op = WASABITI(ti, trec) b0_shift, rb1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(b0_shift, rb1, t1) + (signal,) = model_op(b0_shift, rb1, t1) assert signal.shape == signal_shape diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 6a1120e79..c1738b7bb 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -2,6 +2,7 @@ import pytest import torch +from einops import rearrange from mrpro.data import KTrajectory, SpatialDimension from mrpro.operators import CartesianSamplingOp @@ -16,10 +17,10 @@ def test_cart_sampling_op_data_match(): nkx = (1, 1, 1, 60) nky = (1, 1, 40, 1) nkz = (1, 20, 1, 1) - sx = 'uf' - sy = 'uf' - sz = 'uf' - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + type_kx = 'uniform' + type_ky = 'uniform' + type_kz = 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Create matching data random_generator = RandomGenerator(seed=0) @@ -59,6 +60,9 @@ def test_cart_sampling_op_data_match(): 'regular_undersampling', 'random_undersampling', 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', ], ) def test_cart_sampling_op_fwd_adj(sampling): @@ -69,10 +73,10 @@ def test_cart_sampling_op_fwd_adj(sampling): nkx = (2, 1, 1, 60) nky = (2, 1, 40, 1) nkz = (2, 20, 1, 1) - sx = 'uf' - sy = 'uf' - sz = 'uf' - trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() + type_kx = 'uniform' + type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz).as_tensor() # Subsample data and trajectory match sampling: @@ -94,6 +98,15 @@ def test_cart_sampling_op_fwd_adj(sampling): for traj_one_other in trajectory_tensor.unbind(1) ] trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1)) + case 'cartesian_and_non_cartesian': + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0_undersampling': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + random_idx = torch.randperm(trajectory_tensor.shape[-1]) + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]]) case _: raise NotImplementedError(f'Test {sampling} not implemented.') @@ -105,3 +118,28 @@ def test_cart_sampling_op_fwd_adj(sampling): u = random_generator.complex64_tensor(size=k_shape) v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:]) dotproduct_adjointness_test(sampling_op, u, v) + + +@pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)]) +@pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)]) +def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max): + """Test trajectory points outside of encoding_matrix.""" + encoding_matrix = SpatialDimension(40, 1, 20) + + # Create kx and kz sampling which are asymmetric and larger than the encoding matrix on one side + # The indices are inverted to ensure CartesianSamplingOp acts on them + kx = rearrange(torch.linspace(k0_max, k0_min, 20), 'kx->1 1 1 kx') + ky = torch.ones(1, 1, 1, 1) + kz = rearrange(torch.linspace(k2_max, k2_min, 40), 'kz-> kz 1 1') + kz = torch.stack([kz, -kz], dim=0) # different kz values for two other elements + trajectory = KTrajectory(kz=kz, ky=ky, kx=kx) + + with pytest.warns(UserWarning, match='K-space points lie outside of the encoding_matrix'): + sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) + + random_generator = RandomGenerator(seed=0) + u = random_generator.complex64_tensor(size=(3, 2, 5, kz.shape[-3], ky.shape[-2], kx.shape[-1])) + v = random_generator.complex64_tensor(size=(3, 2, 5, *encoding_matrix.zyx)) + + assert sampling_op.adjoint(u)[0].shape[-3:] == encoding_matrix.zyx + assert sampling_op(v)[0].shape[-3:] == (kz.shape[-3], ky.shape[-2], kx.shape[-1]) diff --git a/tests/operators/test_fast_fourier_op.py b/tests/operators/test_fast_fourier_op.py index 7ac1a8211..7e2f47fd7 100644 --- a/tests/operators/test_fast_fourier_op.py +++ b/tests/operators/test_fast_fourier_op.py @@ -28,7 +28,7 @@ def test_fast_fourier_op_forward(npoints, a): # Transform image to k-space ff_op = FastFourierOp(dim=(0,)) - (igauss_fwd,) = ff_op.forward(igauss) + (igauss_fwd,) = ff_op(igauss) # Scaling to "undo" fft scaling igauss_fwd *= np.sqrt(npoints) / 2 diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 89a4bbc11..2d76642c3 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -1,7 +1,9 @@ """Tests for Fourier operator.""" import pytest -from mrpro.data import SpatialDimension +import torch +from mrpro.data import KData, KTrajectory, SpatialDimension +from mrpro.data.traj_calculators import KTrajectoryCartesian from mrpro.operators import FourierOp from tests import RandomGenerator @@ -9,26 +11,32 @@ from tests.helper import dotproduct_adjointness_test -def create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): random_generator = RandomGenerator(seed=0) # generate random image img = random_generator.complex64_tensor(size=im_shape) # create random trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) return img, trajectory @COMMON_MR_TRAJECTORIES -def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_fourier_op_fwd_adj_property( + im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2 +): """Test adjoint property of Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) # apply forward operator @@ -42,29 +50,52 @@ def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, @pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz'), # parameter names [ - # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), # Cartesian ky dimension defined along k2 rather than k1 - (5, 1, 18, 64), - 'nuf', - 'uf', - 'nuf', + ( # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 96, 1, 1), # nky - Cartesian ky dimension defined along k2 rather than k1 + (5, 1, 18, 64), # nkz + 'non-uniform', # type_kx + 'uniform', # type_ky + 'non-uniform', # type_kz ), ], + ids=['cartesian_fft_dims_not_aligned_with_k2_k1_k0_dims'], ) -def test_fourier_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Test trajectory not supported by Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + +def test_fourier_op_cartesian_sorting(ismrmrd_cart): + """Verify correct sorting of Cartesian k-space data before FFT.""" + kdata = KData.from_file(ismrmrd_cart.filename, KTrajectoryCartesian()) + ff_op = FourierOp.from_kdata(kdata) + (img,) = ff_op.adjoint(kdata.data) + + # shuffle the kspace points along k0 + permutation_index = torch.randperm(kdata.data.shape[-1]) + kdata_unsorted = KData( + header=kdata.header, + data=kdata.data[..., permutation_index], + traj=KTrajectory.from_tensor(kdata.traj.as_tensor()[..., permutation_index]), + ) + ff_op_unsorted = FourierOp.from_kdata(kdata_unsorted) + (img_unsorted,) = ff_op_unsorted.adjoint(kdata_unsorted.data) + + torch.testing.assert_close(img, img_unsorted) diff --git a/tests/operators/test_linearoperatormatrix.py b/tests/operators/test_linearoperatormatrix.py new file mode 100644 index 000000000..7ba87d715 --- /dev/null +++ b/tests/operators/test_linearoperatormatrix.py @@ -0,0 +1,322 @@ +from typing import Any + +import pytest +import torch +from mrpro.operators import EinsumOp, LinearOperator, MagnitudeOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +def random_linearop(size, rng): + """Create a random LinearOperator.""" + return EinsumOp(rng.complex64_tensor(size), '... i j, ... j -> ... i') + + +def random_linearoperatormatrix(size, inner_size, rng): + """Create a random LinearOperatorMatrix.""" + operators = [[random_linearop(inner_size, rng) for i in range(size[1])] for j in range(size[0])] + return LinearOperatorMatrix(operators) + + +def test_linearoperatormatrix_shape(): + """Test creation and shape of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert matrix.shape == (5, 3) + + +def test_linearoperatormatrix_add_matrix(): + """Test addition of two LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 + matrix2)(*vector) + expected = tuple(a + b for a, b in zip(matrix1(*vector), matrix2(*vector), strict=False)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_add_tensor_nonsquare(): + """Test failure of addition of tensor to non-square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + with pytest.raises(NotImplementedError, match='square'): + (matrix1 + other) + + +def test_linearoperatormatrix_add_tensor_square(): + """Add tensor to square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 3), (2, 2), rng) + other = rng.complex64_tensor(2) + vector = rng.complex64_tensor((3, 2)) + result = (matrix1 + other)(*vector) + expected = tuple((mv + other * v for mv, v in zip(matrix1(*vector), vector, strict=True))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_rmul(): + """Test post multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + vector = rng.complex64_tensor((3, 10)) + result = (other * matrix1)(*vector) + expected = tuple(other * el for el in matrix1(*vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_mul(): + """Test pre multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(10) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 * other)(*vector) + expected = matrix1(*(other * el for el in vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition(): + """Test composition of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 @ matrix2)(*vector) + expected = matrix1(*(matrix2(*vector))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition_mismatch(): + """Test composition with mismatching shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((4, 3), (3, 10), rng) + vector = rng.complex64_tensor((4, 10)) + with pytest.raises(ValueError, match='shapes do not match'): + (matrix1 @ matrix2)(*vector) + + +def test_linearoperatormatrix_adjoint(): + """Test adjointness of Adjoint.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + dotproduct_adjointness_test(Wrapper(), rng.complex64_tensor((3, 10)), rng.complex64_tensor((5, 3))) + + +def test_linearoperatormatrix_repr(): + """Test repr of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert 'LinearOperatorMatrix(shape=(5, 3)' in repr(matrix) + + +def test_linearoperatormatrix_getitem(): + """Test slicing of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + def check(actual, expected): + assert tuple(tuple(row) for row in actual) == tuple(tuple(row) for row in expected) + + sliced = matrix[1:3, 2] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[2:3] for row in matrix._operators[1:3]]) + + sliced = matrix[0] + assert sliced.shape == (1, 6) + check(sliced._operators, matrix._operators[:1]) + + sliced = matrix[..., 0] + assert sliced.shape == (12, 1) + check(sliced._operators, [row[:1] for row in matrix._operators]) + + sliced = matrix[1:6:2, (3, 4)] + assert sliced.shape == (3, 2) + check(sliced._operators, [[matrix._operators[i][j] for j in (3, 4)] for i in range(1, 6, 2)]) + + sliced = matrix[-2:-4:-1, -1] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[-1:] for row in matrix._operators[-2:-4:-1]]) + + sliced = matrix[5, 5] + assert isinstance(sliced, LinearOperator) + assert sliced == matrix._operators[5][5] + + +def test_linearoperatormatrix_getitem_error(): + """Test error when slicing with wrong indices.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + with pytest.raises(IndexError, match='Too many indices'): + matrix[1, 1, 1] + + with pytest.raises(IndexError, match='out of range'): + matrix[20] + with pytest.raises(IndexError, match='out of range'): + matrix[-20] + with pytest.raises(IndexError, match='out of range'): + matrix[1:100] + with pytest.raises(IndexError, match='out of range'): + matrix[(100, 1)] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., -20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 1:100] + with pytest.raises(IndexError, match='index type'): + matrix[..., 1.0] + + +def test_linearoperatormatrix_norm_rows(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((3, 1), (3, 10), rng) + vector = rng.complex64_tensor((1, 10)) + result = matrix.operator_norm(*vector) + expected = sum(row[0].operator_norm(vector[0], dim=None) ** 2 for row in matrix._operators) ** 0.5 + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_norm_cols(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((1, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = matrix.operator_norm(*vector) + expected = max(op.operator_norm(v, dim=None) for op, v in zip(matrix._operators[0], vector, strict=False)) + torch.testing.assert_close(result, expected) + + +@pytest.mark.parametrize('seed', [0, 1, 2, 3]) +def test_linearoperatormatrix_norm(seed): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(seed) + matrix = random_linearoperatormatrix((4, 2), (3, 10), rng) + vector = rng.complex64_tensor((2, 10)) + result = matrix.operator_norm(*vector) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + real = Wrapper().operator_norm(vector, dim=None) + + assert result >= real + + +def test_linearoperatormatrix_shorthand_vertical(): + """Test shorthand for vertical stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 10), rng) + op2 = random_linearop((4, 10), rng) + x1 = rng.complex64_tensor((10,)) + + matrix1 = op1 & op2 + assert matrix1.shape == (2, 1) + + actual = matrix1(x1) + expected = (*op1(x1), *op2(x1)) + torch.testing.assert_close(actual, expected) + + matrix2 = op2 & (matrix1 & op1) + assert matrix2.shape == (4, 1) + + matrix3 = matrix2 & matrix2 + assert matrix3.shape == (8, 1) + + actual = matrix3(x1) + expected = 2 * (*op2(x1), *matrix1(x1), *op1(x1)) + torch.testing.assert_close(actual, expected) + + +def test_linearoperatormatrix_shorthand_horizontal(): + """Test shorthand for horizontal stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 4), rng) + op2 = random_linearop((3, 2), rng) + x1 = rng.complex64_tensor((4,)) + x2 = rng.complex64_tensor((2,)) + x3 = rng.complex64_tensor((4,)) + x4 = rng.complex64_tensor((2,)) + + matrix1 = op1 | op2 + assert matrix1.shape == (1, 2) + + actual1 = matrix1(x1, x2) + expected1 = (op1(x1)[0] + op2(x2)[0],) + torch.testing.assert_close(actual1, expected1) + + matrix2 = op2 | (matrix1 | op1) + assert matrix2.shape == (1, 4) + + matrix3 = matrix2 | matrix2 + assert matrix3.shape == (1, 8) + + expected3 = (2 * (op2(x2)[0] + (matrix1(x3, x4)[0] + op1(x1)[0])),) + actual3 = matrix3(x2, x3, x4, x1, x2, x3, x4, x1) + torch.testing.assert_close(actual3, expected3) + + +def test_linearoperatormatrix_stacking_error(): + """Test error when stacking matrix operators with different shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 4), (3, 10), rng) + matrix2 = random_linearoperatormatrix((3, 2), (3, 10), rng) + matrix3 = random_linearoperatormatrix((2, 4), (3, 10), rng) + op = random_linearop((3, 10), rng) + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & matrix2 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | matrix3 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | op + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & op + + +def test_linearoperatormatrix_error_nonlinearop(): + """Test error if trying to create a LinearOperatorMatrix with non linear operator.""" + op: Any = [[MagnitudeOp()]] # Any is used to hide this error from mypy + with pytest.raises(ValueError, match='LinearOperator'): + LinearOperatorMatrix(op) + + +def test_linearoperatormatrix_error_inconsistent_shapes(): + """Test error if trying to create a LinearOperatorMatrix with inonsistent row lengths.""" + rng = RandomGenerator(0) + op = random_linearop((3, 4), rng) + with pytest.raises(ValueError, match='same length'): + LinearOperatorMatrix([[op, op], [op]]) + + +def test_linearoperatormatrix_from_diagonal(): + """Test creation of LinearOperatorMatrix from diagonal.""" + rng = RandomGenerator(0) + ops = [random_linearop((2, 4), rng) for _ in range(3)] + matrix = LinearOperatorMatrix.from_diagonal(*ops) + xs = rng.complex64_tensor((3, 4)) + actual = matrix(*xs) + expected = tuple(op(x)[0] for op, x in zip(ops, xs, strict=False)) + torch.testing.assert_close(actual, expected) diff --git a/tests/operators/test_operator_norm.py b/tests/operators/test_operator_norm.py index d5ae39873..c6e13dcf9 100644 --- a/tests/operators/test_operator_norm.py +++ b/tests/operators/test_operator_norm.py @@ -12,9 +12,8 @@ def test_power_iteration_uses_stopping_criterion(): """Test if the power iteration stops if the absolute and relative tolerance are chosen high.""" - # callback function that should not be called because the power iteration - # should stop if the tolerances are set high - def callback(): + def callback(_): + """Callback function that should not be called, because the power iteration should stop.""" pytest.fail('The power iteration did not stop despite high atol and rtol!') random_generator = RandomGenerator(seed=0) diff --git a/tests/operators/test_pca_compression_op.py b/tests/operators/test_pca_compression_op.py new file mode 100644 index 000000000..a600bcecb --- /dev/null +++ b/tests/operators/test_pca_compression_op.py @@ -0,0 +1,50 @@ +"""Tests for PCA Compression Operator.""" + +import pytest +from mrpro.operators import PCACompressionOp + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +@pytest.mark.parametrize( + ('init_data_shape', 'input_shape', 'n_components'), + [ + ((40, 10), (100, 10), 6), + ((40, 10), (3, 4, 5, 100, 10), 3), + ((3, 4, 40, 10), (3, 4, 100, 10), 6), + ((3, 4, 40, 10), (7, 3, 4, 100, 10), 3), + ], +) +def test_pca_compression_op_adjoint(init_data_shape, input_shape, n_components): + """Test adjointness of PCA Compression Op.""" + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + u = generator.complex64_tensor(input_shape) + output_shape = (*input_shape[:-1], n_components) + v = generator.complex64_tensor(output_shape) + + # Create operator and apply + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=n_components) + dotproduct_adjointness_test(pca_comp_op, u, v) + + +def test_pca_compression_op_wrong_shapes(): + """Test if Operator raises error if shape mismatch.""" + init_data_shape = (10, 6) + input_shape = (100, 3) + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + input_data = generator.complex64_tensor(input_shape) + + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=2) + + with pytest.raises(RuntimeError, match='Matrix'): + pca_comp_op(input_data) + + with pytest.raises(RuntimeError, match='Matrix.H'): + pca_comp_op.adjoint(input_data) diff --git a/tests/operators/test_rearrangeop.py b/tests/operators/test_rearrangeop.py index 3db888897..054c99402 100644 --- a/tests/operators/test_rearrangeop.py +++ b/tests/operators/test_rearrangeop.py @@ -15,6 +15,7 @@ ((2, 2, 4), '... a b->... (a b)', (2, 8), {'b': 4}), # flatten ((2), '... (a b) -> ... a b', (2, 1), {'b': 1}), # unflatten ], + ids=['swap_axes', 'flatten', 'unflatten'], ) def test_einsum_op(input_shape, rule, output_shape, additional_info, dtype): """Test adjointness and shape.""" diff --git a/tests/operators/test_sensitivity_op.py b/tests/operators/test_sensitivity_op.py index 25efc5bee..1321498d8 100644 --- a/tests/operators/test_sensitivity_op.py +++ b/tests/operators/test_sensitivity_op.py @@ -90,7 +90,7 @@ def test_sensitivity_op_other_dim_compatibility_fail(n_other_csm, n_other_img): # Apply to n_other_img shape u = random_generator.complex64_tensor(size=(n_other_img, 1, *n_zyx)) with pytest.raises(RuntimeError, match='The size of tensor'): - sensitivity_op.forward(u) + sensitivity_op(u) v = random_generator.complex64_tensor(size=(n_other_img, n_coils, *n_zyx)) with pytest.raises(RuntimeError, match='The size of tensor'): diff --git a/tests/operators/test_zero_pad_op.py b/tests/operators/test_zero_pad_op.py index a56635f5f..ce2d8855d 100644 --- a/tests/operators/test_zero_pad_op.py +++ b/tests/operators/test_zero_pad_op.py @@ -20,7 +20,7 @@ def test_zero_pad_op_content(): original_shape=tuple([original_shape[d] for d in padding_dimensions]), padded_shape=tuple([padded_shape[d] for d in padding_dimensions]), ) - (padded_data,) = zero_padding_op.forward(original_data) + (padded_data,) = zero_padding_op(original_data) # Compare overlapping region torch.testing.assert_close(original_data[:, 10:90, :, 50:150, :, :], padded_data[:, :, :, :, 95:145, :]) diff --git a/tests/utils/test_modify_acq_info.py b/tests/utils/test_modify_acq_info.py deleted file mode 100644 index 451303d02..000000000 --- a/tests/utils/test_modify_acq_info.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for modification of acquisition infos.""" - -from einops import rearrange -from mrpro.utils import modify_acq_info - - -def test_modify_acq_info(random_kheader_shape): - """Test the modification of the acquisition info.""" - # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] - kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape - - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) - - # Verify shape - assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/utils/test_reshape.py b/tests/utils/test_reshape.py index 60a0dc5e3..dd57b8feb 100644 --- a/tests/utils/test_reshape.py +++ b/tests/utils/test_reshape.py @@ -1,7 +1,9 @@ """Tests for reshaping utilities.""" import torch -from mrpro.utils import broadcast_right, unsqueeze_left, unsqueeze_right +from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right + +from tests import RandomGenerator def test_broadcast_right(): @@ -12,7 +14,7 @@ def test_broadcast_right(): def test_unsqueeze_left(): - """Test unsqueeze left""" + """Test unsqueeze_left""" tensor = torch.ones(1, 2, 3) unsqueezed = unsqueeze_left(tensor, 2) assert unsqueezed.shape == (1, 1, 1, 2, 3) @@ -20,8 +22,32 @@ def test_unsqueeze_left(): def test_unsqueeze_right(): - """Test unsqueeze right""" + """Test unsqueeze_right""" tensor = torch.ones(1, 2, 3) unsqueezed = unsqueeze_right(tensor, 2) assert unsqueezed.shape == (1, 2, 3, 1, 1) assert torch.equal(tensor.ravel(), unsqueezed.ravel()) + + +def test_reduce_view(): + """Test reduce_view""" + + tensor = RandomGenerator(0).float32_tensor((1, 2, 3, 1, 1, 1)) + tensor = tensor.expand(1, 2, 3, 4, 1, 1).contiguous() # this cannot be removed + tensor = tensor.expand(7, 2, 3, 4, 5, 6) + + reduced_all = reduce_view(tensor) + assert reduced_all.shape == (1, 2, 3, 4, 1, 1) + assert torch.equal(reduced_all.expand_as(tensor), tensor) + + reduced_two = reduce_view(tensor, (0, -1)) + assert reduced_two.shape == (1, 2, 3, 4, 5, 1) + assert torch.equal(reduced_two.expand_as(tensor), tensor) + + reduced_one_neg = reduce_view(tensor, -1) + assert reduced_one_neg.shape == (7, 2, 3, 4, 5, 1) + assert torch.equal(reduced_one_neg.expand_as(tensor), tensor) + + reduced_one_pos = reduce_view(tensor, 0) + assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6) + assert torch.equal(reduced_one_pos.expand_as(tensor), tensor) diff --git a/tests/utils/test_unit_conversion.py b/tests/utils/test_unit_conversion.py new file mode 100644 index 000000000..a232a5366 --- /dev/null +++ b/tests/utils/test_unit_conversion.py @@ -0,0 +1,82 @@ +"""Tests of unit conversion.""" + +import numpy as np +import torch +from mrpro.utils.unit_conversion import ( + deg_to_rad, + lamor_frequency_to_magnetic_field, + m_to_mm, + magnetic_field_to_lamor_frequency, + mm_to_m, + ms_to_s, + rad_to_deg, + s_to_ms, +) + +from tests import RandomGenerator + + +def test_mm_to_m(): + """Verify mm to m conversion.""" + generator = RandomGenerator(seed=0) + mm_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(mm_to_m(mm_input), mm_input / 1000.0) + + +def test_m_to_mm(): + """Verify m to mm conversion.""" + generator = RandomGenerator(seed=0) + m_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(m_to_mm(m_input), m_input * 1000.0) + + +def test_ms_to_s(): + """Verify ms to s conversion.""" + generator = RandomGenerator(seed=0) + ms_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(ms_to_s(ms_input), ms_input / 1000.0) + + +def test_s_to_ms(): + """Verify s to ms conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(s_to_ms(s_input), s_input * 1000.0) + + +def test_rad_to_deg_tensor(): + """Verify radians to degree conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(rad_to_deg(s_input), torch.rad2deg(s_input)) + + +def test_deg_to_rad_tensor(): + """Verify degree to radians conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(deg_to_rad(s_input), torch.deg2rad(s_input)) + + +def test_rad_to_deg_float(): + """Verify radians to degree conversion.""" + assert rad_to_deg(np.pi / 2) == 90.0 + + +def test_deg_to_rad_float(): + """Verify degree to radians conversion.""" + assert deg_to_rad(180.0) == np.pi + + +def test_lamor_frequency_to_magnetic_field(): + """Verify conversion of lamor frequency to magnetic field.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + proton_lamor_frequency_at_3tesla = 127.74 * 1e6 + assert lamor_frequency_to_magnetic_field(proton_lamor_frequency_at_3tesla, proton_gyromagnetic_ratio) == 3.0 + + +def test_magnetic_field_to_lamor_frequency(): + """Verify conversion of magnetic field to lamor frequency.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + magnetic_field_strength = 3.0 + assert magnetic_field_to_lamor_frequency(magnetic_field_strength, proton_gyromagnetic_ratio) == 127.74 * 1e6