diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..0981abb00 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +.github +.vscode +binder +examples +docs +tests \ No newline at end of file diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index 9f02e8dcc..a55b7860d 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -1,4 +1,4 @@ -name: Publish to PyPI +name: Build and Deploy Package on: push: @@ -7,21 +7,31 @@ on: jobs: build-testpypi-package: - name: Modify version and build package for TestPyPI + name: Build Package for TestPyPI runs-on: ubuntu-latest outputs: version: ${{ steps.set_suffix.outputs.version }} suffix: ${{ steps.set_suffix.outputs.suffix }} + version_changed: ${{ steps.changes.outputs.version_changed }} steps: - - name: Checkout code - uses: actions/checkout@v4 - - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.x" - - name: Set environment variable for version suffix + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Check if VERSION file is modified compared to main + uses: dorny/paths-filter@v3 + id: changes + with: + base: main + filters: | + version_changed: + - 'src/mrpro/VERSION' + + - name: Set Version Suffix id: set_suffix run: | VERSION=$(cat src/mrpro/VERSION) @@ -30,12 +40,12 @@ jobs: echo "suffix=$SUFFIX" >> $GITHUB_OUTPUT echo "version=$VERSION" >> $GITHUB_OUTPUT - - name: Build package with version suffix + - name: Build Package run: | python -m pip install --upgrade build python -m build - - name: Store TestPyPi distribution + - name: Upload TestPyPI Distribution Artifact uses: actions/upload-artifact@v4 with: name: testpypi-package-distribution @@ -46,6 +56,7 @@ jobs: needs: - build-testpypi-package runs-on: ubuntu-latest + if: needs.build-testpypi-package.outputs.version_changed == 'true' environment: name: testpypi @@ -55,7 +66,7 @@ jobs: id-token: write steps: - - name: Download TestPyPi 
distribution + - name: Download TestPyPI Distribution uses: actions/download-artifact@v4 with: name: testpypi-package-distribution @@ -68,7 +79,7 @@ jobs: verbose: true test-install-from-testpypi: - name: Test installation from TestPyPI + name: Test Installation from TestPyPI needs: - testpypi-deployment - build-testpypi-package @@ -83,17 +94,25 @@ jobs: run: | VERSION=${{ needs.build-testpypi-package.outputs.version }} SUFFIX=${{ needs.build-testpypi-package.outputs.suffix }} - python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ + for i in {1..3}; do + if python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/; then + echo "Package installed successfully." + break + else + echo "Attempt $i failed. Retrying in 10 seconds..." + sleep 10 + fi + done build-pypi-package: - name: Build package for PyPI + name: Build Package for PyPI runs-on: ubuntu-latest needs: - test-install-from-testpypi outputs: version: ${{ steps.get_version.outputs.version }} steps: - - name: Checkout code + - name: Checkout Code uses: actions/checkout@v4 - name: Set up Python @@ -101,22 +120,22 @@ jobs: with: python-version: "3.x" - - name: Install setuptools_git_versioning + - name: Install Automatic Versioning Tool run: | python -m pip install setuptools-git-versioning - - name: Get current version + - name: Get Current Version id: get_version run: | VERSION=$(python -m setuptools_git_versioning) echo "VERSION=$VERSION" >> $GITHUB_OUTPUT - - name: Build package + - name: Build Package run: | python -m pip install --upgrade build python -m build - - name: Store PyPi distribution + - name: Store PyPI Distribution uses: actions/upload-artifact@v4 with: name: pypi-package-distribution @@ -138,13 +157,13 @@ jobs: id-token: write steps: - - name: Download PyPi distribution + - name: Download PyPI Distribution uses: actions/download-artifact@v4 with: 
name: pypi-package-distribution path: dist/ - - name: Create tag + - name: Create Tag uses: actions/github-script@v7 with: script: | @@ -155,7 +174,7 @@ jobs: sha: context.sha }) - - name: Create release + - name: Create Release uses: actions/github-script@v7 with: github-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b5e027cda..fe4f2ae03 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -62,6 +62,7 @@ jobs: permissions: packages: write strategy: + fail-fast: false matrix: dockerfile: ${{ fromJson(needs.get_dockerfiles.outputs.dockerfiles) }} steps: @@ -87,7 +88,9 @@ jobs: - name: Build and push Docker image uses: docker/build-push-action@v6 with: - context: ./docker + context: . + cache-from: type=gha,scope=${{ matrix.dockerfile }} + cache-to: type=gha,mode=max,scope=${{ matrix.dockerfile }} file: ./docker/${{ matrix.dockerfile }} push: true tags: ${{ steps.image_name.outputs.image_name }}:test diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index f102333b2..cc5ff5458 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -92,6 +92,7 @@ jobs: image: ghcr.io/ptb-mr/mrpro_py311:latest options: --user root strategy: + fail-fast: false matrix: notebook: ${{ fromJson(needs.get_notebooks.outputs.notebooks) }} steps: @@ -119,14 +120,14 @@ jobs: run: | notebook=${{ matrix.notebook }} echo "ARTIFACT_NAME=${notebook/.ipynb/}" >> $GITHUB_OUTPUT - echo "HTML_RESULT=${notebook/.ipynb/.html}" >> $GITHUB_OUTPUT + echo "IPYNB_EXECUTED=${notebook}" >> $GITHUB_OUTPUT - name: Upload notebook uses: actions/upload-artifact@v4 if: always() with: name: ${{ steps.artifact_names.outputs.ARTIFACT_NAME }} - path: ${{ github.workspace }}/nb-runner.out/${{ steps.artifact_names.outputs.HTML_RESULT }} + path: ${{ github.workspace }}/nb-runner.out/${{ steps.artifact_names.outputs.IPYNB_EXECUTED }} env: RUNNER: ${{ toJson(runner) }} @@ -149,39 +150,11 @@ 
jobs: - name: Install mrpro and dependencies run: pip install --upgrade --upgrade-strategy "eager" .[docs] - - name: Download notebook html files + - name: Download executed notebook ipynb files id: download uses: actions/download-artifact@v4 with: - path: ./docs/source/notebook_artifact/ - - - name: Copy notebook html files - run: | - mkdir ./docs/source/_notebooks - cd ./docs/source/notebook_artifact/ - notebooks=$(grep -rl --include='*' './') - for nb in $notebooks - do - echo "current jupyter-notebook: $nb" - cp ./$nb ../_notebooks/ - done - - - name: List of notebooks - run: | - cd ./docs/source/_notebooks/ - notebooks=$(grep -rl --include='*.html' './') - cd ../ - echo "" >> examples.rst - for nb in $notebooks - do - echo " notebook_${nb/.html/.rst}" >> examples.rst - notebook_description=$(grep '

\(.*\) "notebook_${nb/.html/.rst}" - echo "========" >> "notebook_${nb/.html/.rst}" - echo ".. raw:: html" >> "notebook_${nb/.html/.rst}" - echo " :file: ./_notebooks/$nb" >> "notebook_${nb/.html/.rst}" - done + path: ./docs/source/_notebooks/ - name: Build docs run: | @@ -194,7 +167,7 @@ jobs: with: name: Documentation path: docs/build/html/ - + - run: echo 'Artifact url ${{ steps.save_docu.outputs.artifact-url }}' - run: echo 'Event number ${{ github.event.number }}' @@ -224,7 +197,7 @@ jobs: deploy: if: github.ref == 'refs/heads/main' permissions: - pages: write + pages: write id-token: write environment: name: github-pages diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 633bf46dc..53d1343c7 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -2,17 +2,21 @@ name: PyTest on: pull_request: + push: + branches: + - main jobs: get_dockerfiles: - name: Get list of dockerfiles for different containers + name: Get List of Dockerfiles for Containers runs-on: ubuntu-latest permissions: packages: read outputs: imagenames: ${{ steps.set-matrix.outputs.imagenames }} steps: - - id: set-matrix + - name: Retrieve Docker Image Names + id: set-matrix env: GH_TOKEN: ${{ secrets.GHCR_TOKEN }} run: | @@ -35,57 +39,56 @@ jobs: echo "image names with tag latest: $imagenames_latest" echo "imagenames=$imagenames_latest" >> $GITHUB_OUTPUT - - name: Dockerfile overview + - name: Dockerfile Overview run: | - echo "final list of images with tag latest: ${{ steps.set-matrix.outputs.imagenames }}" + echo "Final list of images with tag latest: ${{ steps.set-matrix.outputs.imagenames }}" test: - name: Run tests and get coverage report + name: Run Tests and Coverage Report needs: get_dockerfiles runs-on: ubuntu-latest permissions: pull-requests: write contents: write strategy: + fail-fast: false matrix: imagename: ${{ fromJson(needs.get_dockerfiles.outputs.imagenames) }} - # runs within Docker container container: image: 
ghcr.io/ptb-mr/${{ matrix.imagename }}:latest options: --user runner - steps: - - name: Checkout repo + - name: Checkout Repository uses: actions/checkout@v4 - - name: Install mrpro and dependencies - run: pip install --upgrade --upgrade-strategy "eager" .[test] + - name: Install MRpro and Dependencies + run: pip install --upgrade --upgrade-strategy eager .[test] - - name: Install pytest-github-actions-annotate-failures plugin + - name: Install PyTest GitHub Annotation Plugin run: pip install pytest-github-actions-annotate-failures - - name: Run PyTest + - name: Run PyTest and Generate Coverage Report run: | - pytest -n 4 -m "not cuda" --junitxml=pytest.xml --cov-report=term-missing:skip-covered --cov=mrpro | tee pytest-coverage.txt + pytest -n 4 -m "not cuda" --junitxml=pytest.xml \ + --cov-report=term-missing:skip-covered --cov=mrpro | tee pytest-coverage.txt - - name: Check for pytest.xml + - name: Verify PyTest XML Output run: | - if [ -f pytest.xml ]; then - echo "pytest.xml file found. Continuing..." - else - echo "pytest.xml file not found. Please check previous 'Run PyTest' section for errors." + if [ ! -f pytest.xml ]; then + echo "PyTest XML report not found. Please check the previous 'Run PyTest' step for errors." 
exit 1 fi - - name: Pytest coverage comment + - name: Post PyTest Coverage Comment id: coverageComment uses: MishaKav/pytest-coverage-comment@v1.1.53 with: pytest-coverage-path: ./pytest-coverage.txt junitxml-path: ./pytest.xml - - name: Create the Badge + - name: Create Coverage Badge on Main Branch Push uses: schneegans/dynamic-badges-action@v1.7.0 + if: github.event_name == 'push' && github.ref == 'refs/heads/main' with: auth: ${{ secrets.GIST_SECRET }} gistID: 48e334a10caf60e6708d7c712e56d241 @@ -95,12 +98,12 @@ jobs: color: ${{ steps.coverageComment.outputs.color }} namedLogo: python - - name: Set pipeline status + - name: Set Pipeline Status Based on Test Results if: steps.coverageComment.outputs.errors != 0 || steps.coverageComment.outputs.failures != 0 uses: actions/github-script@v7 with: script: | - core.setFailed('PyTest workflow failed with ${{ steps.coverageComment.outputs.errors }} errors and ${{ steps.coverageComment.outputs.failures }} failures.') + core.setFailed("PyTest workflow failed with ${{ steps.coverageComment.outputs.errors }} errors and ${{ steps.coverageComment.outputs.failures }} failures.") concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} diff --git a/.gitignore b/.gitignore index c1694cad9..29cb7951b 100644 --- a/.gitignore +++ b/.gitignore @@ -91,6 +91,7 @@ instance/ # Sphinx documentation docs/_build/ +docs/source/_notebooks/* # PyBuilder .pybuilder/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 64ef2c7a2..2229c194f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,10 +3,9 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - - id: check-docstring-first - id: check-merge-conflict - id: check-yaml - id: check-toml @@ -15,19 +14,26 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff # 
linter args: [--fix] - id: ruff-format # formatter - repo: https://github.com/crate-ci/typos - rev: v1.25.0 + rev: v1.27.0 hooks: - id: typos + - repo: https://github.com/fzimmermann89/check_all + rev: v1.1 + hooks: + - id: check-init-all + args: [--double-quotes, --fix] + exclude: ^tests/ + - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy pass_filenames: false @@ -37,6 +43,7 @@ repos: - numpy - torch>=2.4.0 - types-requests + - typing-extensions - einops - pydicom - matplotlib @@ -44,12 +51,13 @@ repos: - xsdata - "--index-url=https://download.pytorch.org/whl/cpu" - "--extra-index-url=https://pypi.python.org/simple" + ci: - autofix_commit_msg: | - [pre-commit] auto fixes from pre-commit hooks - autofix_prs: false - autoupdate_branch: '' - autoupdate_commit_msg: '[pre-commit] pre-commit autoupdate' - autoupdate_schedule: monthly - skip: [mypy] - submodules: false \ No newline at end of file + autofix_commit_msg: | + [pre-commit] auto fixes from pre-commit hooks + autofix_prs: false + autoupdate_branch: "" + autoupdate_commit_msg: "[pre-commit] pre-commit autoupdate" + autoupdate_schedule: monthly + skip: [mypy] + submodules: false diff --git a/.vscode/extensions.json b/.vscode/extensions.json index a1373702d..e6ef04ea4 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -4,7 +4,7 @@ "davidanson.vscode-markdownlint", "editorconfig.editorconfig", "kevinrose.vsc-python-indent", - "ms-python.black-formatter", + "charliermarsh.ruff", "ms-python.isort", "ms-python.python", "ms-python.vscode-pylance", diff --git a/README.md b/README.md index e869fbfa7..422c4ef4d 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Quantitative parameter maps can be obtained by creating a functional to be minim # Define signal model model = MagnitudeOp() @ InversionRecovery(ti=idata_multi_ti.header.ti) # Define loss function and combine with signal model -mse = MSEDataDiscrepancy(idata_multi_ti.data.abs()) +mse = 
MSE(idata_multi_ti.data.abs()) functional = mse @ model [...] # Run optimization @@ -74,6 +74,8 @@ Full example: `_, +which we have also added to the list of extensions that VSCode should recommend when you open the code. +We also run `mypy `_ as a type checker. + +In CI, our linting is driven by `pre-commit `_. +If you install MRpro via ``pip install -e .[test]``, pre-commit will be installed in your python environment. +You can either add pre-commit to your git pre-commit hooks, requiring it to pass before each commit (``pre-commit install``), +or run it manually using ``pre-commit run --all-files`` after making your changes, before requesting a PR review. + Naming convention ================= -We try to follow the [pep8](https://peps.python.org/pep-0008/) naming convention (e.g., all lowercase variable names, +We try to follow the `PEP 8 `_ naming convention (e.g., all lowercase variable names, CapWords class names). We deviate for the names of source code file names containing a single class. These are named as the class. We try to use descriptive variable names when applicable (e.g., ``result`` instead of ``res``, ``tolerance_squared`` instead of ``sqtol``, ``batchsize`` instead of ``m``). -A name starting with ``n_`` is used for variables describing a number of... (e.g., ``n_coils`` instead of ``ncoils`` or +A name starting with ``n_`` is used for variables describing a number of something (e.g., ``n_coils`` instead of ``ncoils`` or ``num_coils``), variable names ending with ``_op`` for operators (e.g., ``fourier_op``). We use ``img`` as a variable name for images. +Testing +======= +We use pytest for testing. All required packages will be installed if you install MRpro via ``pip install -e .[test]``. +You can use VSCode's test panel to discover and run tests. All tests must pass before a PR can be merged. By default, we skip running CUDA tests. You can use ``pytest -m cuda`` to run the CUDA tests if your development machine has a GPU available. 
+ +Building the Documentation +========================== +You can build the documentation locally via running ```make html``` in the docs folder. The documentation will also be build in each PR and can be viewed online. +Please check how your new additions render in the documentation before requesting a PR review. + +Adding new Examples +=================== +New exciting applications of MRpro can be added in ```examples``` as only ```.py``` files with code-cells. These can, for example, be used in VSCode with the python extension, or in JupyterLab with the `jupytext `_ extension. +An automatic workflow at github will create notebooks and pages in the documentation based on the python scripts. +The data to run the examples should be publicly available and hosted externally, for example at zenodo. +Please be careful not to add any binary files to your commits. + +Release Strategy +================ +We are still in pre-release mode and do not guarantee a stable API / strict semantic versioning compatibility. We currently use ```0.YYMMDD``` as versioning and release in regular intervals to `pypi `_. + +Compatibility +============= +We aim to always be compatible with the latest stable pytorch release and the latest python version supported by pytorch. We are compatible with one previous python version. +Our type hints will usually only be valid with the latest pytorch version. \ No newline at end of file diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 993517fd6..fe24a60a8 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -6,3 +6,7 @@ All of the notebooks can directly be run via binder or colab from the repo. .. 
toctree:: :maxdepth: 1 + :caption: Contents: + :glob: + + _notebooks/*/* diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index d0e6e0adb..52e0310bb 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -76,7 +76,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_ismrmrd_traj = direct_reconstruction.forward(kdata)" + "img_using_ismrmrd_traj = direct_reconstruction(kdata)" ] }, { @@ -100,7 +100,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_rad2d_traj = direct_reconstruction.forward(kdata)" + "img_using_rad2d_traj = direct_reconstruction(kdata)" ] }, { @@ -143,7 +143,7 @@ "\n", "# Reconstruct image\n", "direct_reconstruction = DirectReconstruction(kdata)\n", - "img_using_pulseq_traj = direct_reconstruction.forward(kdata)" + "img_using_pulseq_traj = direct_reconstruction(kdata)" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index 955b20a37..3f857c382 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -37,7 +37,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_ismrmrd_traj = direct_reconstruction.forward(kdata) +img_using_ismrmrd_traj = direct_reconstruction(kdata) # %% [markdown] # ### Image reconstruction using KTrajectoryRadial2D @@ -49,7 +49,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_rad2d_traj = direct_reconstruction.forward(kdata) +img_using_rad2d_traj = direct_reconstruction(kdata) # %% [markdown] # ### Image reconstruction using KTrajectoryPulseq @@ -73,7 +73,7 @@ # Reconstruct image direct_reconstruction = DirectReconstruction(kdata) -img_using_pulseq_traj = direct_reconstruction.forward(kdata) +img_using_pulseq_traj = 
direct_reconstruction(kdata) # %% [markdown] # ### Plot the different reconstructed images diff --git a/examples/qmri_sg_challenge_2024_t1.ipynb b/examples/qmri_sg_challenge_2024_t1.ipynb index 1832e5b48..e5605adfb 100644 --- a/examples/qmri_sg_challenge_2024_t1.ipynb +++ b/examples/qmri_sg_challenge_2024_t1.ipynb @@ -5,7 +5,7 @@ "id": "0f82262f", "metadata": {}, "source": [ - "# QMRI Challenge ISMRM 2024 - T1 mapping" + "# QMRI Challenge ISMRM 2024 - $T_1$ mapping" ] }, { @@ -29,7 +29,7 @@ "from mrpro.algorithms.optimizers import adam\n", "from mrpro.data import IData\n", "from mrpro.operators import MagnitudeOp\n", - "from mrpro.operators.functionals import MSEDataDiscrepancy\n", + "from mrpro.operators.functionals import MSE\n", "from mrpro.operators.models import InversionRecovery" ] }, @@ -40,7 +40,7 @@ "source": [ "### Overview\n", "The dataset consists of images obtained at 10 different inversion times using a turbo spin echo sequence. Each\n", - "inversion time is saved in a separate DICOM file. In order to obtain a T1 map, we are going to:\n", + "inversion time is saved in a separate DICOM file. In order to obtain a $T_1$ map, we are going to:\n", "- download the data from Zenodo\n", "- read in the DICOM files (one for each inversion time) and combine them in an IData object\n", "- define a signal model and data loss (mean-squared error) function\n", @@ -105,7 +105,7 @@ "fig, axes = plt.subplots(1, 3, squeeze=False)\n", "for idx, ax in enumerate(axes.flatten()):\n", " ax.imshow(torch.abs(idata_multi_ti.data[idx, 0, 0, :, :]))\n", - " ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.0f}ms')" + " ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.3f}s')" ] }, { @@ -116,9 +116,9 @@ "### Signal model and loss function\n", "We use the model $q$\n", "\n", - "$q(TI) = M_0 (1 - e^{-TI/T1})$\n", + "$q(TI) = M_0 (1 - e^{-TI/T_1})$\n", "\n", - "with the equilibrium magnetization $M_0$, the inversion time $TI$, and $T1$. 
We have to keep in mind that the DICOM\n", + "with the equilibrium magnetization $M_0$, the inversion time $TI$, and $T_1$. We have to keep in mind that the DICOM\n", "images only contain the magnitude of the signal. Therefore, we need $|q(TI)|$:" ] }, @@ -150,7 +150,7 @@ "metadata": {}, "outputs": [], "source": [ - "mse = MSEDataDiscrepancy(idata_multi_ti.data.abs())" + "mse = MSE(idata_multi_ti.data.abs())" ] }, { @@ -162,7 +162,7 @@ "source": [ "Now we can simply combine the two into a functional to solve\n", "\n", - "$ \\min_{M_0, T1} || |q(M_0, T1, TI)| - x||_2^2$" + "$ \\min_{M_0, T_1} || |q(M_0, T_1, TI)| - x||_2^2$" ] }, { @@ -187,11 +187,11 @@ "To increase our chances of reaching the global minimum, we can ensure that our starting\n", "values are already close to the global minimum. We need a good starting point for each pixel.\n", "\n", - "One option to get a good starting point is to calculate the signal curves for a range of T1 values and then check\n", + "One option to get a good starting point is to calculate the signal curves for a range of $T_1$ values and then check\n", "for each pixel which of these signal curves fits best. This is similar to what is done for MR Fingerprinting. 
So we\n", "are going to:\n", - "- define a list of realistic T1 values (we call this a dictionary of T1 values)\n", - "- calculate the signal curves corresponding to each of these T1 values\n", + "- define a list of realistic $T_1$ values (we call this a dictionary of $T_1$ values)\n", + "- calculate the signal curves corresponding to each of these $T_1$ values\n", "- compare the signal curves to the signals of each voxel (we use the maximum of the dot-product as a metric of how\n", "well the signals fit to each other)" ] @@ -203,8 +203,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Define 100 T1 values between 100 and 3000 ms\n", - "t1_dictionary = torch.linspace(100, 3000, 100)\n", + "# Define 100 T1 values between 0.1 and 3.0 s\n", + "t1_dictionary = torch.linspace(0.1, 3.0, 100)\n", "\n", "# Calculate the signal corresponding to each of these T1 values. We set M0 to 1, but this is arbitrary because M0 is\n", "# just a scaling factor and we are going to normalize the signal curves.\n", @@ -227,8 +227,8 @@ "metadata": {}, "outputs": [], "source": [ - "# The image with the longest inversion time is a good approximation of the equilibrium magnetization\n", - "m0_start = torch.abs(idata_multi_ti.data[torch.argmax(idata_multi_ti.header.ti), ...])" + "# The maximum absolute value observed is a good approximation for m0\n", + "m0_start = torch.amax(torch.abs(idata_multi_ti.data), 0)" ] }, { @@ -242,11 +242,11 @@ "fig, axes = plt.subplots(1, 2, figsize=(8, 2), squeeze=False)\n", "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", "im = axes[0, 0].imshow(m0_start[0, 0, ...])\n", - "axes[0, 0].set_title('M0 start values')\n", + "axes[0, 0].set_title('$M_0$ start values')\n", "fig.colorbar(im, cax=colorbar_ax[0])\n", - "im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2500)\n", - "axes[0, 1].set_title('T1 start values')\n", - "fig.colorbar(im, cax=colorbar_ax[1])" + "im = axes[0, 1].imshow(t1_start[0, 0, 
...], vmin=0, vmax=2.5)\n", + "axes[0, 1].set_title('$T_1$ start values')\n", + "fig.colorbar(im, cax=colorbar_ax[1], label='s')" ] }, { @@ -266,7 +266,7 @@ "source": [ "# Hyperparameters for optimizer\n", "max_iter = 2000\n", - "lr = 1e0\n", + "lr = 1e-1\n", "\n", "# Run optimization\n", "params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)\n", @@ -283,7 +283,7 @@ "### Visualize the final results\n", "To get an impression of how well the fit has worked, we are going to calculate the relative error between\n", "\n", - "$E_{relative} = \\sum_{TI}\\frac{|(q(M_0, T1, TI) - x)|}{|x|}$\n", + "$E_{relative} = \\sum_{TI}\\frac{|(q(M_0, T_1, TI) - x)|}{|x|}$\n", "\n", "on a voxel-by-voxel basis" ] @@ -304,11 +304,11 @@ "fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False)\n", "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", "im = axes[0, 0].imshow(m0[0, 0, ...])\n", - "axes[0, 0].set_title('M0')\n", + "axes[0, 0].set_title('$M_0$')\n", "fig.colorbar(im, cax=colorbar_ax[0])\n", - "im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2500)\n", - "axes[0, 1].set_title('T1')\n", - "fig.colorbar(im, cax=colorbar_ax[1])\n", + "im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2.5)\n", + "axes[0, 1].set_title('$T_1$')\n", + "fig.colorbar(im, cax=colorbar_ax[1], label='s')\n", "im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...], vmin=0, vmax=1.0)\n", "axes[0, 2].set_title('Relative error')\n", "fig.colorbar(im, cax=colorbar_ax[2])" diff --git a/examples/qmri_sg_challenge_2024_t1.py b/examples/qmri_sg_challenge_2024_t1.py index 148f5bdd0..d0259f267 100644 --- a/examples/qmri_sg_challenge_2024_t1.py +++ b/examples/qmri_sg_challenge_2024_t1.py @@ -1,5 +1,5 @@ # %% [markdown] -# # QMRI Challenge ISMRM 2024 - T1 mapping +# # QMRI Challenge ISMRM 2024 - $T_1$ mapping # %% # Imports @@ -16,13 +16,13 @@ from mrpro.algorithms.optimizers import adam from mrpro.data import IData from 
mrpro.operators import MagnitudeOp -from mrpro.operators.functionals import MSEDataDiscrepancy +from mrpro.operators.functionals import MSE from mrpro.operators.models import InversionRecovery # %% [markdown] # ### Overview # The dataset consists of images obtained at 10 different inversion times using a turbo spin echo sequence. Each -# inversion time is saved in a separate DICOM file. In order to obtain a T1 map, we are going to: +# inversion time is saved in a separate DICOM file. In order to obtain a $T_1$ map, we are going to: # - download the data from Zenodo # - read in the DICOM files (one for each inversion time) and combine them in an IData object # - define a signal model and data loss (mean-squared error) function @@ -53,15 +53,15 @@ fig, axes = plt.subplots(1, 3, squeeze=False) for idx, ax in enumerate(axes.flatten()): ax.imshow(torch.abs(idata_multi_ti.data[idx, 0, 0, :, :])) - ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.0f}ms') + ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.3f}s') # %% [markdown] # ### Signal model and loss function # We use the model $q$ # -# $q(TI) = M_0 (1 - e^{-TI/T1})$ +# $q(TI) = M_0 (1 - e^{-TI/T_1})$ # -# with the equilibrium magnetization $M_0$, the inversion time $TI$, and $T1$. We have to keep in mind that the DICOM +# with the equilibrium magnetization $M_0$, the inversion time $TI$, and $T_1$. We have to keep in mind that the DICOM # images only contain the magnitude of the signal. Therefore, we need $|q(TI)|$: # %% @@ -71,12 +71,12 @@ # As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal # model $q$. 
# %% -mse = MSEDataDiscrepancy(idata_multi_ti.data.abs()) +mse = MSE(idata_multi_ti.data.abs()) # %% [markdown] # Now we can simply combine the two into a functional to solve # -# $ \min_{M_0, T1} || |q(M_0, T1, TI)| - x||_2^2$ +# $ \min_{M_0, T_1} || |q(M_0, T_1, TI)| - x||_2^2$ # %% functional = mse @ model @@ -88,17 +88,17 @@ # To increase our chances of reaching the global minimum, we can ensure that our starting # values are already close to the global minimum. We need a good starting point for each pixel. # -# One option to get a good starting point is to calculate the signal curves for a range of T1 values and then check +# One option to get a good starting point is to calculate the signal curves for a range of $T_1$ values and then check # for each pixel which of these signal curves fits best. This is similar to what is done for MR Fingerprinting. So we # are going to: -# - define a list of realistic T1 values (we call this a dictionary of T1 values) -# - calculate the signal curves corresponding to each of these T1 values +# - define a list of realistic $T_1$ values (we call this a dictionary of $T_1$ values) +# - calculate the signal curves corresponding to each of these $T_1$ values # - compare the signal curves to the signals of each voxel (we use the maximum of the dot-product as a metric of how # well the signals fit to each other) # %% -# Define 100 T1 values between 100 and 3000 ms -t1_dictionary = torch.linspace(100, 3000, 100) +# Define 100 T1 values between 0.1 and 3.0 s +t1_dictionary = torch.linspace(0.1, 3.0, 100) # Calculate the signal corresponding to each of these T1 values. We set M0 to 1, but this is arbitrary because M0 is # just a scaling factor and we are going to normalize the signal curves. 
@@ -114,19 +114,19 @@ t1_start = rearrange(t1_dictionary[idx_best_match], '(y x)->1 1 y x', y=n_y, x=n_x) # %% -# The image with the longest inversion time is a good approximation of the equilibrium magnetization -m0_start = torch.abs(idata_multi_ti.data[torch.argmax(idata_multi_ti.header.ti), ...]) +# The maximum absolute value observed is a good approximation for m0 +m0_start = torch.amax(torch.abs(idata_multi_ti.data), 0) # %% # Visualize the starting values fig, axes = plt.subplots(1, 2, figsize=(8, 2), squeeze=False) colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] im = axes[0, 0].imshow(m0_start[0, 0, ...]) -axes[0, 0].set_title('M0 start values') +axes[0, 0].set_title('$M_0$ start values') fig.colorbar(im, cax=colorbar_ax[0]) -im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2500) -axes[0, 1].set_title('T1 start values') -fig.colorbar(im, cax=colorbar_ax[1]) +im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2.5) +axes[0, 1].set_title('$T_1$ start values') +fig.colorbar(im, cax=colorbar_ax[1], label='s') # %% [markdown] # ### Carry out fit @@ -134,7 +134,7 @@ # %% # Hyperparameters for optimizer max_iter = 2000 -lr = 1e0 +lr = 1e-1 # Run optimization params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr) @@ -146,7 +146,7 @@ # ### Visualize the final results # To get an impression of how well the fit has worked, we are going to calculate the relative error between # -# $E_{relative} = \sum_{TI}\frac{|(q(M_0, T1, TI) - x)|}{|x|}$ +# $E_{relative} = \sum_{TI}\frac{|(q(M_0, T_1, TI) - x)|}{|x|}$ # # on a voxel-by-voxel basis @@ -158,11 +158,11 @@ fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False) colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] im = axes[0, 0].imshow(m0[0, 0, ...]) -axes[0, 0].set_title('M0') +axes[0, 0].set_title('$M_0$') fig.colorbar(im, cax=colorbar_ax[0]) -im = axes[0, 
1].imshow(t1[0, 0, ...], vmin=0, vmax=2500) -axes[0, 1].set_title('T1') -fig.colorbar(im, cax=colorbar_ax[1]) +im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2.5) +axes[0, 1].set_title('$T_1$') +fig.colorbar(im, cax=colorbar_ax[1], label='s') im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...], vmin=0, vmax=1.0) axes[0, 2].set_title('Relative error') fig.colorbar(im, cax=colorbar_ax[2]) diff --git a/examples/qmri_sg_challenge_2024_t2_star.ipynb b/examples/qmri_sg_challenge_2024_t2_star.ipynb index e2adcde8c..ad4122033 100644 --- a/examples/qmri_sg_challenge_2024_t2_star.ipynb +++ b/examples/qmri_sg_challenge_2024_t2_star.ipynb @@ -5,7 +5,7 @@ "id": "5efa18e9", "metadata": {}, "source": [ - "# QMRI Challenge ISMRM 2024 - T2* mapping" + "# QMRI Challenge ISMRM 2024 - $T_2^*$ mapping" ] }, { @@ -28,7 +28,7 @@ "from mpl_toolkits.axes_grid1 import make_axes_locatable # type: ignore [import-untyped]\n", "from mrpro.algorithms.optimizers import adam\n", "from mrpro.data import IData\n", - "from mrpro.operators.functionals import MSEDataDiscrepancy\n", + "from mrpro.operators.functionals import MSE\n", "from mrpro.operators.models import MonoExponentialDecay" ] }, @@ -39,7 +39,7 @@ "source": [ "### Overview\n", "The dataset consists of gradient echo images obtained at 11 different echo times, each saved in a separate DICOM file.\n", - "In order to obtain a T2* map, we are going to:\n", + "In order to obtain a $T_2^*$ map, we are going to:\n", "- download the data from Zenodo\n", "- read in the DICOM files (one for each echo time) and combine them in an IData object\n", "- define a signal model (mono-exponential decay) and data loss (mean-squared error) function\n", @@ -100,6 +100,8 @@ "source": [ "te_dicom_files = data_folder.glob('**/*.dcm')\n", "idata_multi_te = IData.from_dicom_files(te_dicom_files)\n", + "# scaling the signal down to make the optimization easier\n", + "idata_multi_te.data[...] 
= idata_multi_te.data / 1500\n", "\n", "# Move the data to the GPU\n", "if flag_use_cuda:\n", @@ -120,7 +122,7 @@ "fig, axes = plt.subplots(1, 3, squeeze=False)\n", "for idx, ax in enumerate(axes.flatten()):\n", " ax.imshow(torch.abs(idata_multi_te.data[idx, 0, 0, :, :]).cpu())\n", - " ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.0f}ms')" + " ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.3f}s')" ] }, { @@ -131,9 +133,9 @@ "### Signal model and loss function\n", "We use the model $q$\n", "\n", - "$q(TE) = M_0 e^{-TE/T2^*}$\n", + "$q(TE) = M_0 e^{-TE/T_2^*}$\n", "\n", - "with the equilibrium magnetization $M_0$, the echo time $TE$, and $T2^*$" + "with the equilibrium magnetization $M_0$, the echo time $TE$, and $T_2^*$" ] }, { @@ -164,7 +166,7 @@ "metadata": {}, "outputs": [], "source": [ - "mse = MSEDataDiscrepancy(idata_multi_te.data)" + "mse = MSE(idata_multi_te.data)" ] }, { @@ -176,7 +178,7 @@ "source": [ "Now we can simply combine the two into a functional which will then solve\n", "\n", - "$ \\min_{M_0, T2^*} ||q(M_0, T2^*, TE) - x||_2^2$" + "$ \\min_{M_0, T_2^*} ||q(M_0, T_2^*, TE) - x||_2^2$" ] }, { @@ -207,11 +209,11 @@ "# The shortest echo time is a good approximation of the equilibrium magnetization\n", "m0_start = torch.abs(idata_multi_te.data[torch.argmin(idata_multi_te.header.te), ...])\n", "# 20 ms as a starting value for T2*\n", - "t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20\n", + "t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20e-3\n", "\n", "# Hyperparameters for optimizer\n", "max_iter = 20000\n", - "lr = 1e0\n", + "lr = 1e-3\n", "\n", "if flag_use_cuda:\n", " functional.cuda()\n", @@ -235,7 +237,7 @@ "### Visualize the final results\n", "To get an impression of how well the fit has worked, we are going to calculate the relative error between\n", "\n", - "$E_{relative} = \\sum_{TE}\\frac{|(q(M_0, T2^*, TE) - x)|}{|x|}$\n", + "$E_{relative} = 
\\sum_{TE}\\frac{|(q(M_0, T_2^*, TE) - x)|}{|x|}$\n", "\n", "on a voxel-by-voxel basis." ] @@ -257,12 +259,12 @@ "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", "\n", "im = axes[0, 0].imshow(m0[0, 0, ...].cpu())\n", - "axes[0, 0].set_title('M0')\n", + "axes[0, 0].set_title('$M_0$')\n", "fig.colorbar(im, cax=colorbar_ax[0])\n", "\n", - "im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=500)\n", - "axes[0, 1].set_title('T2*')\n", - "fig.colorbar(im, cax=colorbar_ax[1])\n", + "im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=5)\n", + "axes[0, 1].set_title('$T_2^*$')\n", + "fig.colorbar(im, cax=colorbar_ax[1], label='s')\n", "\n", "im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...].cpu(), vmin=0, vmax=0.1)\n", "axes[0, 2].set_title('Relative error')\n", diff --git a/examples/qmri_sg_challenge_2024_t2_star.py b/examples/qmri_sg_challenge_2024_t2_star.py index ced49ae49..a80f40754 100644 --- a/examples/qmri_sg_challenge_2024_t2_star.py +++ b/examples/qmri_sg_challenge_2024_t2_star.py @@ -1,5 +1,5 @@ # %% [markdown] -# # QMRI Challenge ISMRM 2024 - T2* mapping +# # QMRI Challenge ISMRM 2024 - $T_2^*$ mapping # %% # Imports @@ -15,13 +15,13 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable # type: ignore [import-untyped] from mrpro.algorithms.optimizers import adam from mrpro.data import IData -from mrpro.operators.functionals import MSEDataDiscrepancy +from mrpro.operators.functionals import MSE from mrpro.operators.models import MonoExponentialDecay # %% [markdown] # ### Overview # The dataset consists of gradient echo images obtained at 11 different echo times, each saved in a separate DICOM file. 
-# In order to obtain a T2* map, we are going to: +# In order to obtain a $T_2^*$ map, we are going to: # - download the data from Zenodo # - read in the DICOM files (one for each echo time) and combine them in an IData object # - define a signal model (mono-exponential decay) and data loss (mean-squared error) function @@ -48,6 +48,8 @@ # %% te_dicom_files = data_folder.glob('**/*.dcm') idata_multi_te = IData.from_dicom_files(te_dicom_files) +# scaling the signal down to make the optimization easier +idata_multi_te.data[...] = idata_multi_te.data / 1500 # Move the data to the GPU if flag_use_cuda: @@ -61,15 +63,15 @@ fig, axes = plt.subplots(1, 3, squeeze=False) for idx, ax in enumerate(axes.flatten()): ax.imshow(torch.abs(idata_multi_te.data[idx, 0, 0, :, :]).cpu()) - ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.0f}ms') + ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.3f}s') # %% [markdown] # ### Signal model and loss function # We use the model $q$ # -# $q(TE) = M_0 e^{-TE/T2^*}$ +# $q(TE) = M_0 e^{-TE/T_2^*}$ # -# with the equilibrium magnetization $M_0$, the echo time $TE$, and $T2^*$ +# with the equilibrium magnetization $M_0$, the echo time $TE$, and $T_2^*$ # %% model = MonoExponentialDecay(decay_time=idata_multi_te.header.te) @@ -78,12 +80,12 @@ # As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal # model $q$. 
# %% -mse = MSEDataDiscrepancy(idata_multi_te.data) +mse = MSE(idata_multi_te.data) # %% [markdown] # Now we can simply combine the two into a functional which will then solve # -# $ \min_{M_0, T2^*} ||q(M_0, T2^*, TE) - x||_2^2$ +# $ \min_{M_0, T_2^*} ||q(M_0, T_2^*, TE) - x||_2^2$ # %% functional = mse @ model @@ -94,11 +96,11 @@ # The shortest echo time is a good approximation of the equilibrium magnetization m0_start = torch.abs(idata_multi_te.data[torch.argmin(idata_multi_te.header.te), ...]) # 20 ms as a starting value for T2* -t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20 +t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20e-3 # Hyperparameters for optimizer max_iter = 20000 -lr = 1e0 +lr = 1e-3 if flag_use_cuda: functional.cuda() @@ -115,7 +117,7 @@ # ### Visualize the final results # To get an impression of how well the fit has worked, we are going to calculate the relative error between # -# $E_{relative} = \sum_{TE}\frac{|(q(M_0, T2^*, TE) - x)|}{|x|}$ +# $E_{relative} = \sum_{TE}\frac{|(q(M_0, T_2^*, TE) - x)|}{|x|}$ # # on a voxel-by-voxel basis. 
# %% @@ -127,12 +129,12 @@ colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] im = axes[0, 0].imshow(m0[0, 0, ...].cpu()) -axes[0, 0].set_title('M0') +axes[0, 0].set_title('$M_0$') fig.colorbar(im, cax=colorbar_ax[0]) -im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=500) -axes[0, 1].set_title('T2*') -fig.colorbar(im, cax=colorbar_ax[1]) +im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=5) +axes[0, 1].set_title('$T_2^*$') +fig.colorbar(im, cax=colorbar_ax[1], label='s') im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...].cpu(), vmin=0, vmax=0.1) axes[0, 2].set_title('Relative error') diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb new file mode 100644 index 000000000..6b1c2704b --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "af432293", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Regularized Iterative SENSE Reconstruction of 2D golden angle radial data\n", + "Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a7a6ce3", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# define zenodo URL of the example ismrmd data\n", + "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cd8486b", + "metadata": {}, + "outputs": [], + "source": [ + "# Download raw data\n", + "import tempfile\n", + "\n", + "import requests\n", + "\n", + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + 
"data_file.write(response.content)\n", + "data_file.flush()" + ] + }, + { + "cell_type": "markdown", + "id": "6a9defa1", + "metadata": {}, + "source": [ + "### Image reconstruction\n", + "We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data.\n", + "RegularizedIterativeSENSEReconstruction solves the following reconstruction problem:\n", + "\n", + "Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms,\n", + "coil sensitivity maps...) $A$ then we can formulate the forward problem as:\n", + "\n", + "$ y = Ax + n $\n", + "\n", + "where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 $\n", + "\n", + "where $W^\\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal\n", + "operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain\n", + "a solution with certain properties:\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$\n", + "\n", + "where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image.\n", + "With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$.\n", + "\n", + "Setting the derivative of the functional $F$ to zero and rearranging yields\n", + "\n", + "$ (A^H W A + l B) x = A^H W y + l x_{reg}$\n", + "\n", + "which is a linear system $Hx = b$ that needs to be solved for $x$.\n", + "\n", + "One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution\n", + "dynamic images or high-quality static images have been proposed. 
In recent years, also the output of neural-networks\n", + "has been used as an image regulariser.\n", + "\n", + "In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image.\n", + "Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using\n", + "only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the\n", + "regularization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4da15c2", + "metadata": {}, + "outputs": [], + "source": [ + "import mrpro" + ] + }, + { + "cell_type": "markdown", + "id": "de055070", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Read-in the raw data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ac1d89f", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.data import KData\n", + "from mrpro.data.traj_calculators import KTrajectoryIsmrmrd\n", + "\n", + "# Load in the Data and the trajectory from the ISMRMRD file\n", + "kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd())\n", + "kdata.header.recon_matrix.x = 256\n", + "kdata.header.recon_matrix.y = 256" + ] + }, + { + "cell_type": "markdown", + "id": "1f389140", + "metadata": {}, + "source": [ + "##### Image $x_{reg}$ from fully sampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "212b915c", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction\n", + "from mrpro.data import CsmData\n", + "\n", + "# Estimate coil maps\n", + "direct_reconstruction = DirectReconstruction(kdata, csm=None)\n", + "img_coilwise = direct_reconstruction(kdata)\n", + "csm = CsmData.from_idata_walsh(img_coilwise)\n", + "\n", + "# Iterative SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, 
csm=csm, n_iterations=3)\n", + "img_iterative_sense = iterative_sense_reconstruction(kdata)" + ] + }, + { + "cell_type": "markdown", + "id": "bec6b712", + "metadata": {}, + "source": [ + "##### Image $x$ from undersampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6740447", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Data undersampling, i.e. take only the first 20 radial lines\n", + "idx_us = torch.arange(0, 20)[None, :]\n", + "kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fbfd664", + "metadata": {}, + "outputs": [], + "source": [ + "# Iterativ SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6)\n", + "img_us_iterative_sense = iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "041ffe72", + "metadata": {}, + "outputs": [], + "source": [ + "# Regularized iterativ SENSE reconstruction\n", + "from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction\n", + "\n", + "regularization_weight = 1.0\n", + "n_iterations = 6\n", + "regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction(\n", + " kdata_us,\n", + " csm=csm,\n", + " n_iterations=n_iterations,\n", + " regularization_data=img_iterative_sense.data,\n", + " regularization_weight=regularization_weight,\n", + ")\n", + "img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d5bbec1", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()]\n", + "vis_title = ['Fully sampled', 
'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4))\n", + "for ind in range(3):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "2bd49c87", + "metadata": {}, + "source": [ + "### Behind the scenes" + ] + }, + { + "cell_type": "markdown", + "id": "53779251", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Set-up the density compensation operator $W$ and acquisition model $A$\n", + "\n", + "This is very similar to the iterative SENSE reconstruction. For more detail please look at the\n", + "iterative_sense_reconstruction notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e985a4f3", + "metadata": {}, + "outputs": [], + "source": [ + "dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator()\n", + "fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us)\n", + "csm_operator = csm.as_operator()\n", + "acquisition_operator = fourier_operator @ csm_operator" + ] + }, + { + "cell_type": "markdown", + "id": "2daa0fee", + "metadata": {}, + "source": [ + "##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac1d5fb4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "right_hand_side = (\n", + " acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "76a0b153", + "metadata": {}, + "source": [ + "##### Set-up the linear self-adjoint operator $H = A^H W A + l$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5effb592", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.operators import IdentityOp\n", + "\n", + "operator = acquisition_operator.H @ 
dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor(\n", + " regularization_weight\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f24a8588", + "metadata": {}, + "source": [ + "##### Run conjugate gradient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96827838", + "metadata": {}, + "outputs": [], + "source": [ + "img_manual = mrpro.algorithms.optimizers.cg(\n", + " operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18c065a7", + "metadata": {}, + "outputs": [], + "source": [ + "# Display the reconstructed image\n", + "vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]]\n", + "vis_title = ['Regularized Iterative SENSE R=20', '\"Manual\" Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4))\n", + "for ind in range(2):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "d6d7efdf", + "metadata": {}, + "source": [ + "### Check for equal results\n", + "The two versions should result in the same image data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f59b6015", + "metadata": {}, + "outputs": [], + "source": [ + "# If the assert statement did not raise an exception, the results are equal.\n", + "assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual)" + ] + }, + { + "cell_type": "markdown", + "id": "6ecd6e70", + "metadata": {}, + "source": [ + "### Next steps\n", + "Play around with the regularization_weight to see how it effects the final image quality.\n", + "\n", + "Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications\n", + "we would not have that. 
One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the\n", + "streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality\n", + "compared to the unregularised images." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py new file mode 100644 index 000000000..e41dc4ac5 --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -0,0 +1,193 @@ +# %% [markdown] +# # Regularized Iterative SENSE Reconstruction of 2D golden angle radial data +# Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data +# %% +# define zenodo URL of the example ismrmd data +zenodo_url = 'https://zenodo.org/records/10854057/files/' +fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' +# %% +# Download raw data +import tempfile + +import requests + +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() + +# %% [markdown] +# ### Image reconstruction +# We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data. +# RegularizedIterativeSENSEReconstruction solves the following reconstruction problem: +# +# Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms, +# coil sensitivity maps...) $A$ then we can formulate the forward problem as: +# +# $ y = Ax + n $ +# +# where $n$ describes complex Gaussian noise. 
The image $x$ can be obtained by minimizing the functionl $F$ +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 $ +# +# where $W^\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal +# operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain +# a solution with certain properties: +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$ +# +# where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image. +# With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$. +# +# Setting the derivative of the functional $F$ to zero and rearranging yields +# +# $ (A^H W A + l B) x = A^H W y + l x_{reg}$ +# +# which is a linear system $Hx = b$ that needs to be solved for $x$. +# +# One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution +# dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks +# has been used as an image regulariser. +# +# In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image. +# Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using +# only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the +# regularization. 
+ +# %% +import mrpro + +# %% [markdown] +# ##### Read-in the raw data +# %% +from mrpro.data import KData +from mrpro.data.traj_calculators import KTrajectoryIsmrmrd + +# Load in the Data and the trajectory from the ISMRMRD file +kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd()) +kdata.header.recon_matrix.x = 256 +kdata.header.recon_matrix.y = 256 + +# %% [markdown] +# ##### Image $x_{reg}$ from fully sampled data + +# %% +from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction +from mrpro.data import CsmData + +# Estimate coil maps +direct_reconstruction = DirectReconstruction(kdata, csm=None) +img_coilwise = direct_reconstruction(kdata) +csm = CsmData.from_idata_walsh(img_coilwise) + +# Iterative SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3) +img_iterative_sense = iterative_sense_reconstruction(kdata) + +# %% [markdown] +# ##### Image $x$ from undersampled data + +# %% +import torch + +# Data undersampling, i.e. 
take only the first 20 radial lines +idx_us = torch.arange(0, 20)[None, :] +kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition') + +# %% +# Iterativ SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6) +img_us_iterative_sense = iterative_sense_reconstruction(kdata_us) + +# %% +# Regularized iterativ SENSE reconstruction +from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction + +regularization_weight = 1.0 +n_iterations = 6 +regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction( + kdata_us, + csm=csm, + n_iterations=n_iterations, + regularization_data=img_iterative_sense.data, + regularization_weight=regularization_weight, +) +img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us) + +# %% +import matplotlib.pyplot as plt + +vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()] +vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4)) +for ind in range(3): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + + +# %% [markdown] +# ### Behind the scenes + +# %% [markdown] +# ##### Set-up the density compensation operator $W$ and acquisition model $A$ +# +# This is very similar to the iterative SENSE reconstruction. For more detail please look at the +# iterative_sense_reconstruction notebook. 
+# %% +dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator() +fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us) +csm_operator = csm.as_operator() +acquisition_operator = fourier_operator @ csm_operator + +# %% [markdown] +# ##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$ + +# %% +right_hand_side = ( + acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data +) + + +# %% [markdown] +# ##### Set-up the linear self-adjoint operator $H = A^H W A + l$ + +# %% +from mrpro.operators import IdentityOp + +operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor( + regularization_weight +) + +# %% [markdown] +# ##### Run conjugate gradient + +# %% +img_manual = mrpro.algorithms.optimizers.cg( + operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0 +) + +# %% +# Display the reconstructed image +vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]] +vis_title = ['Regularized Iterative SENSE R=20', '"Manual" Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4)) +for ind in range(2): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + +# %% [markdown] +# ### Check for equal results +# The two versions should result in the same image data. + +# %% +# If the assert statement did not raise an exception, the results are equal. +assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual) + +# %% [markdown] +# ### Next steps +# Play around with the regularization_weight to see how it effects the final image quality. +# +# Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications +# we would not have that. 
One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the +# streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality +# compared to the unregularised images. diff --git a/examples/ruff.toml b/examples/ruff.toml index 11a1e6167..1bb114755 100644 --- a/examples/ruff.toml +++ b/examples/ruff.toml @@ -5,4 +5,5 @@ lint.extend-ignore = [ "T20", #print "E402", #module-import-not-at-top-of-file "S101", #assert + "SIM115", #context manager for opening files ] diff --git a/examples/t1_mapping_with_grad_acq.ipynb b/examples/t1_mapping_with_grad_acq.ipynb index 743f7ad8e..cfe252e13 100644 --- a/examples/t1_mapping_with_grad_acq.ipynb +++ b/examples/t1_mapping_with_grad_acq.ipynb @@ -5,7 +5,7 @@ "id": "83bfb574", "metadata": {}, "source": [ - "# T1 mapping from a continuous Golden radial acquisition" + "# $T_1$ mapping from a continuous Golden radial acquisition" ] }, { @@ -29,42 +29,95 @@ "from mrpro.data import KData\n", "from mrpro.data.traj_calculators import KTrajectoryIsmrmrd\n", "from mrpro.operators import ConstraintsOp, MagnitudeOp\n", - "from mrpro.operators.functionals import MSEDataDiscrepancy\n", + "from mrpro.operators.functionals import MSE\n", "from mrpro.operators.models import TransientSteadyStateWithPreparation\n", "from mrpro.utils import split_idx" ] }, { "cell_type": "markdown", - "id": "7f7c1229", - "metadata": {}, + "id": "29eabc2a", + "metadata": { + "lines_to_next_cell": 2 + }, "source": [ "### Overview\n", "In this acquisition, a single inversion pulse is played out, followed by a continuous data acquisition with a\n", "a constant flip angle $\\alpha$. Data acquisition is carried out with a 2D Golden angle radial trajectory. The acquired\n", "data can be divided into different dynamic time frames, each corresponding to a different inversion time. A signal\n", - "model can then be fitted to this data to obtain a $T_1$ map. 
More information can be found in:\n", - "\n", - "Kerkering KM, Schulz-Menger J, Schaeffter T, Kolbitsch C (2023) Motion-corrected model-based reconstruction for 2D\n", - "myocardial T1 mapping, MRM 90 https://doi.org/10.1002/mrm.29699\n", + "model can then be fitted to this data to obtain a $T_1$ map.\n", "\n", + "More information can be found in:\n", + "Kerkering KM, Schulz-Menger J, Schaeffter T, Kolbitsch C (2023). Motion-corrected model-based reconstruction for 2D\n", + "myocardial $T_1$ mapping. *Magnetic Resonance in Medicine*, 90(3):1086-1100, [10.1002/mrm.29699](https://doi.org/10.1002/mrm.29699)" + ] + }, + { + "cell_type": "markdown", + "id": "2f2c110e", + "metadata": {}, + "source": [ "The number of time frames and hence the number of radial lines per time frame, can in principle be chosen arbitrarily.\n", "However, a tradeoff between image quality (more radial lines per dynamic) and\n", - "temporal resolution to accurately capture the signal behavior (fewer radial lines) needs to be found.\n", - "\n", + "temporal resolution to accurately capture the signal behavior (fewer radial lines) needs to be found." 
+ ] + }, + { + "cell_type": "markdown", + "id": "1ed1fc05", + "metadata": {}, + "source": [ "During data acquisition, the magnetization $M_z(t)$ can be described by the signal model:\n", - " $$ M_z(t) = M_0^* + (M_0^{init} - M_0^*)e^{(-t / T_1^*)} \\quad (1) $$\n", + "\n", + "$$\n", + " M_z(t) = M_0^* + (M_0^{init} - M_0^*)e^{(-t / T_1^*)} \\quad (1)\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "8b1e3c2f", + "metadata": {}, + "source": [ "where the effective longitudinal relaxation time is given by:\n", - " $$ T_1^* = \\frac{1}{\\frac{1}{T1} - \\frac{1}{T_R} ln(cos(\\alpha))} $$\n", + "\n", + "$$\n", + " T_1^* = \\frac{1}{\\frac{1}{T_1} - \\frac{1}{T_R} \\ln(\\cos(\\alpha))}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "1c6c6616", + "metadata": {}, + "source": [ "and the steady-state magnetization is\n", - " $$ M_0^* = M_0 \\frac{T_1^*}{T_1} .$$\n", "\n", + "$$\n", + " M_0^* = M_0 \\frac{T_1^*}{T_1} .\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "52b8c555", + "metadata": {}, + "source": [ "The initial magnetization $M_0^{init}$ after an inversion pulse is $-M_0$. Nevertheless, commonly after an inversion\n", "pulse, a strong spoiler gradient is played out to remove any residual transversal magnetization due to\n", "imperfections of the inversion pulse. During the spoiler gradient, the magnetization recovers with $T_1$. Commonly,\n", "the duration of this spoiler gradient $\\Delta t$ is between 10 to 20 ms. 
This leads to the initial magnetization\n", - " $$ M_0^{init} = M_0(1 - 2e^{(-\\Delta t / T_1)}) .$$\n", "\n", + "$$\n", + " M_0^{init} = M_0(1 - 2e^{(-\\Delta t / T_1)}) .\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "7f7c1229", + "metadata": {}, + "source": [ "In this example, we are going to:\n", "- Reconstruct a single high quality image using all acquired radial lines.\n", "- Split the data into multiple dynamics and reconstruct these dynamic images\n", @@ -82,7 +135,7 @@ "source": [ "# Download raw data in ISMRMRD format from zenodo into a temporary directory\n", "data_folder = Path(tempfile.mkdtemp())\n", - "dataset = '10671597'\n", + "dataset = '13207352'\n", "zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder]) # r: retries" ] }, @@ -182,7 +235,7 @@ "id": "87260553", "metadata": {}, "source": [ - "## Estimate T1 map" + "## Estimate $T_1$ map" ] }, { @@ -223,9 +276,9 @@ "source": [ "We also need the repetition time between two RF-pulses. There is a parameter `tr` in the header, but this describes\n", "the time \"between the beginning of a pulse sequence and the beginning of the succeeding (essentially identical) pulse\n", - "sequence\" (see https://dicom.innolitics.com/ciods/mr-image/mr-image/00180080). We have one inversion pulse at the\n", - "beginning, which is never repeated and hence `tr` is the duration of the entire scan. Therefore, we have to use the\n", - "parameter `echo_spacing`, which describes the time between two gradient echoes." + "sequence\" (see [DICOM Standard Browser](https://dicom.innolitics.com/ciods/mr-image/mr-image/00180080)). We have one\n", + "inversion pulse at the beginning, which is never repeated and hence `tr` is the duration of the entire scan.\n", + "Therefore, we have to use the parameter `echo_spacing`, which describes the time between two gradient echoes." 
] }, { @@ -317,7 +370,7 @@ }, "source": [ "### Loss function\n", - "As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal\n", + "As a loss function for the optimizer, we calculate the mean squared error between the image data $x$ and our signal\n", "model $q$." ] }, @@ -328,7 +381,7 @@ "metadata": {}, "outputs": [], "source": [ - "mse_loss = MSEDataDiscrepancy(img_rss_dynamic)" + "mse_loss = MSE(img_rss_dynamic)" ] }, { @@ -340,7 +393,9 @@ "source": [ "Now we can simply combine the loss function, the signal model and the constraints to solve\n", "\n", - "$$ \\min_{M_0, T_1, \\alpha} || |q(M_0, T_1, \\alpha)| - x||_2^2$$" + "$$\n", + " \\min_{M_0, T_1, \\alpha} || |q(M_0, T_1, \\alpha)| - x||_2^2\n", + "$$" ] }, { @@ -406,10 +461,10 @@ "fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False)\n", "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", "im = axes[0, 0].imshow(m0[0, ...].abs(), cmap='gray')\n", - "axes[0, 0].set_title('M0')\n", + "axes[0, 0].set_title('$M_0$')\n", "fig.colorbar(im, cax=colorbar_ax[0])\n", "im = axes[0, 1].imshow(t1[0, ...], vmin=0, vmax=2)\n", - "axes[0, 1].set_title('T1 (s)')\n", + "axes[0, 1].set_title('$T_1$ (s)')\n", "fig.colorbar(im, cax=colorbar_ax[1])\n", "im = axes[0, 2].imshow(flip_angle[0, ...] 
/ torch.pi * 180, vmin=0, vmax=8)\n", "axes[0, 2].set_title('Flip angle (°)')\n", diff --git a/examples/t1_mapping_with_grad_acq.py b/examples/t1_mapping_with_grad_acq.py index 29c08b031..de8e31c43 100644 --- a/examples/t1_mapping_with_grad_acq.py +++ b/examples/t1_mapping_with_grad_acq.py @@ -1,5 +1,5 @@ # %% [markdown] -# # T1 mapping from a continuous Golden radial acquisition +# # $T_1$ mapping from a continuous Golden radial acquisition # %% # Imports @@ -16,7 +16,7 @@ from mrpro.data import KData from mrpro.data.traj_calculators import KTrajectoryIsmrmrd from mrpro.operators import ConstraintsOp, MagnitudeOp -from mrpro.operators.functionals import MSEDataDiscrepancy +from mrpro.operators.functionals import MSE from mrpro.operators.models import TransientSteadyStateWithPreparation from mrpro.utils import split_idx @@ -25,28 +25,50 @@ # In this acquisition, a single inversion pulse is played out, followed by a continuous data acquisition with a # a constant flip angle $\alpha$. Data acquisition is carried out with a 2D Golden angle radial trajectory. The acquired # data can be divided into different dynamic time frames, each corresponding to a different inversion time. A signal -# model can then be fitted to this data to obtain a $T_1$ map. More information can be found in: -# -# Kerkering KM, Schulz-Menger J, Schaeffter T, Kolbitsch C (2023) Motion-corrected model-based reconstruction for 2D -# myocardial T1 mapping, MRM 90 https://doi.org/10.1002/mrm.29699 +# model can then be fitted to this data to obtain a $T_1$ map. # +# More information can be found in: +# Kerkering KM, Schulz-Menger J, Schaeffter T, Kolbitsch C (2023). Motion-corrected model-based reconstruction for 2D +# myocardial $T_1$ mapping. *Magnetic Resonance in Medicine*, 90(3):1086-1100, [10.1002/mrm.29699](https://doi.org/10.1002/mrm.29699) + + +# %% [markdown] # The number of time frames and hence the number of radial lines per time frame, can in principle be chosen arbitrarily. 
# However, a tradeoff between image quality (more radial lines per dynamic) and # temporal resolution to accurately capture the signal behavior (fewer radial lines) needs to be found. -# + +# %% [markdown] # During data acquisition, the magnetization $M_z(t)$ can be described by the signal model: -# $$ M_z(t) = M_0^* + (M_0^{init} - M_0^*)e^{(-t / T_1^*)} \quad (1) $$ +# +# $$ +# M_z(t) = M_0^* + (M_0^{init} - M_0^*)e^{(-t / T_1^*)} \quad (1) +# $$ + +# %% [markdown] # where the effective longitudinal relaxation time is given by: -# $$ T_1^* = \frac{1}{\frac{1}{T1} - \frac{1}{T_R} ln(cos(\alpha))} $$ +# +# $$ +# T_1^* = \frac{1}{\frac{1}{T_1} - \frac{1}{T_R} \ln(\cos(\alpha))} +# $$ + +# %% [markdown] # and the steady-state magnetization is -# $$ M_0^* = M_0 \frac{T_1^*}{T_1} .$$ # +# $$ +# M_0^* = M_0 \frac{T_1^*}{T_1} . +# $$ + +# %% [markdown] # The initial magnetization $M_0^{init}$ after an inversion pulse is $-M_0$. Nevertheless, commonly after an inversion # pulse, a strong spoiler gradient is played out to remove any residual transversal magnetization due to # imperfections of the inversion pulse. During the spoiler gradient, the magnetization recovers with $T_1$. Commonly, # the duration of this spoiler gradient $\Delta t$ is between 10 to 20 ms. This leads to the initial magnetization -# $$ M_0^{init} = M_0(1 - 2e^{(-\Delta t / T_1)}) .$$ # +# $$ +# M_0^{init} = M_0(1 - 2e^{(-\Delta t / T_1)}) . +# $$ + +# %% [markdown] # In this example, we are going to: # - Reconstruct a single high quality image using all acquired radial lines. 
# - Split the data into multiple dynamics and reconstruct these dynamic images @@ -55,7 +77,7 @@ # %% # Download raw data in ISMRMRD format from zenodo into a temporary directory data_folder = Path(tempfile.mkdtemp()) -dataset = '10671597' +dataset = '13207352' zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder]) # r: retries @@ -105,7 +127,7 @@ cax.set_title(f'Dynamic {idx}') # %% [markdown] -# ## Estimate T1 map +# ## Estimate $T_1$ map # %% [markdown] # ### Signal model @@ -129,9 +151,9 @@ # %% [markdown] # We also need the repetition time between two RF-pulses. There is a parameter `tr` in the header, but this describes # the time "between the beginning of a pulse sequence and the beginning of the succeeding (essentially identical) pulse -# sequence" (see https://dicom.innolitics.com/ciods/mr-image/mr-image/00180080). We have one inversion pulse at the -# beginning, which is never repeated and hence `tr` is the duration of the entire scan. Therefore, we have to use the -# parameter `echo_spacing`, which describes the time between two gradient echoes. +# sequence" (see [DICOM Standard Browser](https://dicom.innolitics.com/ciods/mr-image/mr-image/00180080)). We have one +# inversion pulse at the beginning, which is never repeated and hence `tr` is the duration of the entire scan. +# Therefore, we have to use the parameter `echo_spacing`, which describes the time between two gradient echoes. # %% if kdata_dynamic.header.echo_spacing is None: @@ -173,15 +195,17 @@ # %% [markdown] # ### Loss function -# As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal +# As a loss function for the optimizer, we calculate the mean squared error between the image data $x$ and our signal # model $q$. 
# %% -mse_loss = MSEDataDiscrepancy(img_rss_dynamic) +mse_loss = MSE(img_rss_dynamic) # %% [markdown] # Now we can simply combine the loss function, the signal model and the constraints to solve # -# $$ \min_{M_0, T_1, \alpha} || |q(M_0, T_1, \alpha)| - x||_2^2$$ +# $$ +# \min_{M_0, T_1, \alpha} || |q(M_0, T_1, \alpha)| - x||_2^2 +# $$ # %% functional = mse_loss @ magnitude_model_op @ constraints_op @@ -212,10 +236,10 @@ fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False) colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] im = axes[0, 0].imshow(m0[0, ...].abs(), cmap='gray') -axes[0, 0].set_title('M0') +axes[0, 0].set_title('$M_0$') fig.colorbar(im, cax=colorbar_ax[0]) im = axes[0, 1].imshow(t1[0, ...], vmin=0, vmax=2) -axes[0, 1].set_title('T1 (s)') +axes[0, 1].set_title('$T_1$ (s)') fig.colorbar(im, cax=colorbar_ax[1]) im = axes[0, 2].imshow(flip_angle[0, ...] / torch.pi * 180, vmin=0, vmax=8) axes[0, 2].set_title('Flip angle (°)') diff --git a/pyproject.toml b/pyproject.toml index 606893d7b..6f5d09048 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ include-package-data = true name = "mrpro" description = "MR image reconstruction and processing package specifically developed for PyTorch." 
readme = "README.md" -requires-python = ">=3.11,<3.14" +requires-python = ">=3.10,<3.14" dynamic = ["version"] keywords = ["MRI, reconstruction, processing, PyTorch"] authors = [ @@ -34,10 +34,12 @@ authors = [ { name = "Johannes Hammacher", email = "johannnes.hammacher@ptb.de" }, { name = "Stefan Martin", email = "stefan.martin@ptb.de" }, { name = "Andreas Kofler", email = "andreas.kofler@ptb.de" }, + { name = "Catarina Redshaw Kranich", email = "catarina.redshaw-kranich@ptb.de" }, ] classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", + "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3 :: Only", @@ -52,6 +54,7 @@ dependencies = [ "torchkbnufft>=1.4.0", "scipy>=1.12", "ptwt>=0.1.8", + "typing-extensions>=4.12", ] [project.optional-dependencies] @@ -63,7 +66,12 @@ test = [ "pytest-cov", "pytest-xdist", ] -docs = ["sphinx", "sphinx_rtd_theme", "sphinx-pyproject"] +docs = ["sphinx", + "sphinx_rtd_theme", + "sphinx-pyproject", + "myst-nb", + "sphinx-mathjax-offline", + ] notebook = [ "zenodo_get", "ipykernel", @@ -82,6 +90,7 @@ testpaths = ["tests"] filterwarnings = [ "error", "ignore:'write_like_original':DeprecationWarning:pydicom:", + "ignore:Anomaly Detection has been enabled:UserWarning", #torch.autograd ] addopts = "-n auto" markers = ["cuda : Tests only to be run when cuda device is available"] @@ -174,12 +183,14 @@ skip-magic-trailing-comma = false [tool.typos.default] locale = "en-us" +check-filename = false [tool.typos.default.extend-words] -Reson = "Reson" # required for Proc. Intl. Soc. Mag. Reson. Med. +Reson = "Reson" # required for Proc. Intl. Soc. Mag. Reson. Med. 
iy = "iy" -daa = 'daa' # required for wavelet operator -gaus = 'gaus' # required for wavelet operator +daa = "daa" # required for wavelet operator +gaus = "gaus" # required for wavelet operator +arange = "arange" # torch.arange [tool.typos.files] extend-exclude = [ diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 0a67c464f..c60039027 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241015 +0.241112 diff --git a/src/mrpro/__init__.py b/src/mrpro/__init__.py index 4a0ac4ca0..729ae188c 100644 --- a/src/mrpro/__init__.py +++ b/src/mrpro/__init__.py @@ -1,4 +1,10 @@ from mrpro._version import __version__ from mrpro import algorithms, operators, data, phantoms, utils -__all__ = ["algorithms", "operators", "data", "phantoms", "utils"] - +__all__ = [ + "__version__", + "algorithms", + "data", + "operators", + "phantoms", + "utils" +] diff --git a/src/mrpro/algorithms/__init__.py b/src/mrpro/algorithms/__init__.py index 6a5d8cd64..3dbc8ed07 100644 --- a/src/mrpro/algorithms/__init__.py +++ b/src/mrpro/algorithms/__init__.py @@ -1,3 +1,3 @@ from mrpro.algorithms import csm, optimizers, reconstruction from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -__all__ = ["csm", "optimizers", "reconstruction", "prewhiten_kspace"] +__all__ = ["csm", "optimizers", "prewhiten_kspace", "reconstruction"] \ No newline at end of file diff --git a/src/mrpro/algorithms/csm/__init__.py b/src/mrpro/algorithms/csm/__init__.py index 499984f53..255e09a89 100644 --- a/src/mrpro/algorithms/csm/__init__.py +++ b/src/mrpro/algorithms/csm/__init__.py @@ -1,3 +1,3 @@ from mrpro.algorithms.csm.walsh import walsh from mrpro.algorithms.csm.inati import inati -__all__ = ["walsh", "inati"] +__all__ = ["inati", "walsh"] \ No newline at end of file diff --git a/src/mrpro/algorithms/optimizers/OptimizerStatus.py b/src/mrpro/algorithms/optimizers/OptimizerStatus.py index b682cf380..064bedb82 100644 --- a/src/mrpro/algorithms/optimizers/OptimizerStatus.py +++ 
b/src/mrpro/algorithms/optimizers/OptimizerStatus.py @@ -1,8 +1,7 @@ """Optimizer Status base class.""" -from typing import TypedDict - import torch +from typing_extensions import TypedDict class OptimizerStatus(TypedDict): diff --git a/src/mrpro/algorithms/optimizers/adam.py b/src/mrpro/algorithms/optimizers/adam.py index ed145b41b..bbf6eeac3 100644 --- a/src/mrpro/algorithms/optimizers/adam.py +++ b/src/mrpro/algorithms/optimizers/adam.py @@ -6,11 +6,11 @@ from torch.optim import Adam, AdamW from mrpro.algorithms.optimizers.OptimizerStatus import OptimizerStatus -from mrpro.operators.Operator import Operator +from mrpro.operators.Operator import OperatorType def adam( - f: Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor]], + f: OperatorType, initial_parameters: Sequence[torch.Tensor], max_iter: int, lr: float = 1e-3, diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index 54cb1d782..2d458bfa0 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -81,13 +81,8 @@ def cg( if torch.vdot(residual.flatten(), residual.flatten()) == 0: return solution - # squared tolerance; - # (we will check ||residual||^2 < tolerance^2 instead of ||residual|| < tol - # to avoid the computation of the root for the norm) - tolerance_squared = tolerance**2 - # dummy value. 
new value will be set in loop before first usage - residual_norm_squared_previous = None + residual_norm_squared_previous: torch.Tensor | None = None for iteration in range(max_iterations): # calculate the square norm of the residual @@ -95,18 +90,18 @@ def cg( residual_norm_squared = torch.vdot(residual_flat, residual_flat).real # check if the solution is already accurate enough - if tolerance != 0 and (residual_norm_squared < tolerance_squared): + if tolerance != 0 and (residual_norm_squared < tolerance**2): return solution - if iteration > 0: + if residual_norm_squared_previous is not None: # not first iteration beta = residual_norm_squared / residual_norm_squared_previous conjugate_vector = residual + beta * conjugate_vector # update estimates of the solution and the residual (operator_conjugate_vector,) = operator(conjugate_vector) alpha = residual_norm_squared / (torch.vdot(conjugate_vector.flatten(), operator_conjugate_vector.flatten())) - solution += alpha * conjugate_vector - residual -= alpha * operator_conjugate_vector + solution = solution + alpha * conjugate_vector + residual = residual - alpha * operator_conjugate_vector residual_norm_squared_previous = residual_norm_squared diff --git a/src/mrpro/algorithms/optimizers/lbfgs.py b/src/mrpro/algorithms/optimizers/lbfgs.py index b6825d88e..736dcc0b0 100644 --- a/src/mrpro/algorithms/optimizers/lbfgs.py +++ b/src/mrpro/algorithms/optimizers/lbfgs.py @@ -7,11 +7,11 @@ from torch.optim import LBFGS from mrpro.algorithms.optimizers.OptimizerStatus import OptimizerStatus -from mrpro.operators.Operator import Operator +from mrpro.operators.Operator import OperatorType def lbfgs( - f: Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor]], + f: OperatorType, initial_parameters: Sequence[torch.Tensor], lr: float = 1.0, max_iter: int = 100, diff --git a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py index 
04d37128c..32785a91a 100644 --- a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py +++ b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py @@ -4,20 +4,17 @@ from collections.abc import Callable -import torch - -from mrpro.algorithms.optimizers.cg import cg -from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import ( + RegularizedIterativeSENSEReconstruction, +) from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData -from mrpro.data.IData import IData from mrpro.data.KNoise import KNoise from mrpro.operators.LinearOperator import LinearOperator -class IterativeSENSEReconstruction(DirectReconstruction): +class IterativeSENSEReconstruction(RegularizedIterativeSENSEReconstruction): r"""Iterative SENSE reconstruction. This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2` @@ -49,6 +46,8 @@ def __init__( ) -> None: """Initialize IterativeSENSEReconstruction. + For a regularized version of the iterative SENSE algorithm please see RegularizedIterativeSENSEReconstruction. + Parameters ---------- kdata @@ -74,76 +73,4 @@ def __init__( ValueError If the kdata and fourier_op are None or if csm is a Callable but kdata is None. """ - super().__init__(kdata, fourier_op, csm, noise, dcf) - self.n_iterations = n_iterations - - def _self_adjoint_operator(self) -> LinearOperator: - """Create the self-adjoint operator. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Create the self-adjoint operator as :math:`H = A^H W A` if the DCF is not None otherwise as :math:`H = A^H A`. 
- """ - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Create H = A^H W A - operator = operator.H @ dcf_operator @ operator - else: - # Create H = A^H A - operator = operator.H @ operator - - return operator - - def _right_hand_side(self, kdata: KData) -> torch.Tensor: - """Calculate the right-hand-side of the normal equation. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Calculate the right-hand-side as :math:`b = A^H W y` if the DCF is not None otherwise as :math:`b = A^H y`. - - Parameters - ---------- - kdata - k-space data to reconstruct. - """ - device = kdata.data.device - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Calculate b = A^H W y - (right_hand_side,) = operator.to(device).H(dcf_operator(kdata.data)[0]) - else: - # Calculate b = A^H y - (right_hand_side,) = operator.to(device).H(kdata.data) - - return right_hand_side - - def forward(self, kdata: KData) -> IData: - """Apply the reconstruction. - - Parameters - ---------- - kdata - k-space data to reconstruct. - - Returns - ------- - the reconstruced image. 
- """ - device = kdata.data.device - if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) - - operator = self._self_adjoint_operator().to(device) - right_hand_side = self._right_hand_side(kdata) - - img_tensor = cg( - operator, right_hand_side, initial_value=right_hand_side, max_iterations=self.n_iterations, tolerance=0.0 - ) - img = IData.from_tensor_and_kheader(img_tensor, kdata.header) - return img + super().__init__(kdata, fourier_op, csm, noise, dcf, n_iterations=n_iterations, regularization_weight=0) diff --git a/src/mrpro/algorithms/reconstruction/Reconstruction.py b/src/mrpro/algorithms/reconstruction/Reconstruction.py index f26e92737..c4208157e 100644 --- a/src/mrpro/algorithms/reconstruction/Reconstruction.py +++ b/src/mrpro/algorithms/reconstruction/Reconstruction.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Literal, Self +from typing import Literal import torch +from typing_extensions import Self from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace from mrpro.data._kdata.KData import KData @@ -100,15 +101,13 @@ def direct_reconstruction(self, kdata: KData) -> IData: ------- image data """ - device = kdata.data.device if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) + kdata = prewhiten_kspace(kdata, self.noise) operator = self.fourier_op if self.csm is not None: operator = operator @ self.csm.as_operator() if self.dcf is not None: operator = self.dcf.as_operator() @ operator - operator = operator.to(device) (img_tensor,) = operator.H(kdata.data) img = IData.from_tensor_and_kheader(img_tensor, kdata.header) return img diff --git a/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py new file mode 100644 index 000000000..c9a307ebe --- /dev/null +++ 
b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py @@ -0,0 +1,139 @@ +"""Regularized Iterative SENSE Reconstruction by adjoint Fourier transform.""" + +from __future__ import annotations + +from collections.abc import Callable + +import torch + +from mrpro.algorithms.optimizers.cg import cg +from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace +from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.data._kdata.KData import KData +from mrpro.data.CsmData import CsmData +from mrpro.data.DcfData import DcfData +from mrpro.data.IData import IData +from mrpro.data.KNoise import KNoise +from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperator import LinearOperator + + +class RegularizedIterativeSENSEReconstruction(DirectReconstruction): + r"""Regularized iterative SENSE reconstruction. + + This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2 + + \frac{1}{2}L||Bx - x_0||_2^2` + by using a conjugate gradient algorithm to solve + :math:`H x = b` with :math:`H = A^H W A + L B` and :math:`b = A^H W y + L x_0` where :math:`A` + is the acquisition model (coil sensitivity maps, Fourier operator, k-space sampling), :math:`y` is the acquired + k-space data, :math:`W` describes the density compensation, :math:`L` is the strength of the regularization and + :math:`x_0` is the regularization image (i.e. the prior). :math:`B` is a linear operator applied to :math:`x`. + """ + + n_iterations: int + """Number of CG iterations.""" + + regularization_data: torch.Tensor + """Regularization data (i.e. 
prior) :math:`x_0`.""" + + regularization_weight: torch.Tensor + """Strength of the regularization :math:`L`.""" + + regularization_op: LinearOperator + """Linear operator :math:`B` applied to the current estimate in the regularization term.""" + + def __init__( + self, + kdata: KData | None = None, + fourier_op: LinearOperator | None = None, + csm: Callable | CsmData | None = CsmData.from_idata_walsh, + noise: KNoise | None = None, + dcf: DcfData | None = None, + *, + n_iterations: int = 5, + regularization_data: float | torch.Tensor = 0.0, + regularization_weight: float | torch.Tensor, + regularization_op: LinearOperator | None = None, + ) -> None: + """Initialize RegularizedIterativeSENSEReconstruction. + + For a unregularized version of the iterative SENSE algorithm the regularization_weight can be set to 0 or + IterativeSENSEReconstruction algorithm can be used. + + Parameters + ---------- + kdata + KData. If kdata is provided and fourier_op or dcf are None, then fourier_op and dcf are estimated based on + kdata. Otherwise fourier_op and dcf are used as provided. + fourier_op + Instance of the FourierOperator used for reconstruction. If None, set up based on kdata. + csm + Sensitivity maps for coil combination. If None, no coil combination is carried out, i.e. images for each + coil are returned. If a callable is provided, coil images are reconstructed using the adjoint of the + FourierOperator (including density compensation) and then sensitivity maps are calculated using the + callable. For this, kdata needs also to be provided. For examples have a look at the CsmData class + e.g. from_idata_walsh or from_idata_inati. + noise + KNoise used for prewhitening. If None, no prewhitening is performed + dcf + K-space sampling density compensation. If None, set up based on kdata. + n_iterations + Number of CG iterations + regularization_data + Regularization data, e.g. a reference image (:math:`x_0`). 
+ regularization_weight + Strength of the regularization (:math:`L`). + regularization_op + Linear operator :math:`B` applied to the current estimate in the regularization term. If None, nothing is + applied to the current estimate. + + + Raises + ------ + ValueError + If the kdata and fourier_op are None or if csm is a Callable but kdata is None. + """ + super().__init__(kdata, fourier_op, csm, noise, dcf) + self.n_iterations = n_iterations + self.regularization_data = torch.as_tensor(regularization_data) + self.regularization_weight = torch.as_tensor(regularization_weight) + self.regularization_op = regularization_op if regularization_op is not None else IdentityOp() + + def forward(self, kdata: KData) -> IData: + """Apply the reconstruction. + + Parameters + ---------- + kdata + k-space data to reconstruct. + + Returns + ------- + the reconstruced image. + """ + if self.noise is not None: + kdata = prewhiten_kspace(kdata, self.noise) + + # Create the normal operator as H = A^H W A if the DCF is not None otherwise as H = A^H A. + # The acquisition model is A = F S if the CSM S is defined otherwise A = F with the Fourier operator F + csm_op = self.csm.as_operator() if self.csm is not None else IdentityOp() + precondition_op = self.dcf.as_operator() if self.dcf is not None else IdentityOp() + operator = (self.fourier_op @ csm_op).H @ precondition_op @ (self.fourier_op @ csm_op) + + # Calculate the right-hand-side as b = A^H W y if the DCF is not None otherwise as b = A^H y. 
+ (right_hand_side,) = (self.fourier_op @ csm_op).H(precondition_op(kdata.data)[0]) + + # Add regularization + if not torch.all(self.regularization_weight == 0): + operator = operator + IdentityOp() @ (self.regularization_weight * self.regularization_op) + right_hand_side += self.regularization_weight * self.regularization_data + + img_tensor = cg( + operator, + right_hand_side, + initial_value=right_hand_side, + max_iterations=self.n_iterations, + tolerance=0.0, + ) + img = IData.from_tensor_and_kheader(img_tensor, kdata.header) + return img diff --git a/src/mrpro/algorithms/reconstruction/__init__.py b/src/mrpro/algorithms/reconstruction/__init__.py index 1d292cf37..38b539f8b 100644 --- a/src/mrpro/algorithms/reconstruction/__init__.py +++ b/src/mrpro/algorithms/reconstruction/__init__.py @@ -1,4 +1,10 @@ from mrpro.algorithms.reconstruction.Reconstruction import Reconstruction from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import RegularizedIterativeSENSEReconstruction from mrpro.algorithms.reconstruction.IterativeSENSEReconstruction import IterativeSENSEReconstruction -__all__ = ["Reconstruction", "DirectReconstruction", "IterativeSENSEReconstruction"] +__all__ = [ + "DirectReconstruction", + "IterativeSENSEReconstruction", + "Reconstruction", + "RegularizedIterativeSENSEReconstruction" +] diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 525c164e0..f5d677f97 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -1,28 +1,29 @@ """Acquisition information dataclass.""" -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass -from typing import Self, TypeVar import ismrmrd import numpy as np import torch +from einops import rearrange +from typing_extensions import Self from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation 
import Rotation from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.utils.unit_conversion import mm_to_m -# Conversion functions for units -T = TypeVar('T', float, torch.Tensor) +def rearrange_acq_info_fields(field: object, pattern: str, **axes_lengths: dict[str, int]) -> object: + """Change the shape of the fields in AcqInfo.""" + if isinstance(field, Rotation): + return Rotation.from_matrix(rearrange(field.as_matrix(), pattern, **axes_lengths)) -def ms_to_s(ms: T) -> T: - """Convert ms to s.""" - return ms / 1000 + if isinstance(field, torch.Tensor): + return rearrange(field, pattern, **axes_lengths) - -def mm_to_m(m: T) -> T: - """Convert mm to m.""" - return m / 1000 + return field @dataclass(slots=True) @@ -121,30 +122,24 @@ class AcqInfo(MoveDataMixin): number_of_samples: torch.Tensor """Number of sample points per readout (readouts may have different number of sample points).""" + orientation: Rotation + """Rotation describing the orientation of the readout, phase and slice encoding direction.""" + patient_table_position: SpatialDimension[torch.Tensor] """Offset position of the patient table, in LPS coordinates [m].""" - phase_dir: SpatialDimension[torch.Tensor] - """Directional cosine of phase encoding (2D).""" - physiology_time_stamp: torch.Tensor """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units""" position: SpatialDimension[torch.Tensor] """Center of the excited volume, in LPS coordinates relative to isocenter [m].""" - read_dir: SpatialDimension[torch.Tensor] - """Directional cosine of readout/frequency encoding.""" - sample_time_us: torch.Tensor """Readout bandwidth, as time between samples [us].""" scan_counter: torch.Tensor """Zero-indexed incrementing counter for readouts.""" - slice_dir: SpatialDimension[torch.Tensor] - """Directional cosine of slice normal, i.e. cross-product of read_dir and phase_dir.""" - trajectory_dimensions: torch.Tensor # =3. 
We only support 3D Trajectories: kz always exists. """Dimensionality of the k-space trajectory vector.""" @@ -206,15 +201,13 @@ def tensor_2d(data: np.ndarray) -> torch.Tensor: data_tensor = data_tensor[None, None] return data_tensor - def spatialdimension_2d( - data: np.ndarray, conversion: Callable[[torch.Tensor], torch.Tensor] | None = None - ) -> SpatialDimension[torch.Tensor]: + def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: # Ensure spatial dimension is (k1*k2*other, 1, 3) if data.ndim != 2: raise ValueError('Spatial dimension is expected to be of shape (N,3)') data = data[:, None, :] # all spatial dimensions are float32 - return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32)), conversion) + return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))) acq_idx = AcqIdx( k1=tensor(idx['kspace_encode_step_1']), @@ -249,14 +242,16 @@ def spatialdimension_2d( flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), - patient_table_position=spatialdimension_2d(headers['patient_table_position'], mm_to_m), - phase_dir=spatialdimension_2d(headers['phase_dir']), + orientation=Rotation.from_directions( + spatialdimension_2d(headers['slice_dir']), + spatialdimension_2d(headers['phase_dir']), + spatialdimension_2d(headers['read_dir']), + ), + patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), - position=spatialdimension_2d(headers['position'], mm_to_m), - read_dir=spatialdimension_2d(headers['read_dir']), + position=spatialdimension_2d(headers['position']).apply_(mm_to_m), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), - slice_dir=spatialdimension_2d(headers['slice_dir']), 
trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above user_float=tensor_2d(headers['user_float']), user_int=tensor_2d(headers['user_int']), diff --git a/src/mrpro/data/CsmData.py b/src/mrpro/data/CsmData.py index 2913d1d20..000884f4b 100644 --- a/src/mrpro/data/CsmData.py +++ b/src/mrpro/data/CsmData.py @@ -2,9 +2,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING import torch +from typing_extensions import Self from mrpro.data.IData import IData from mrpro.data.QData import QData diff --git a/src/mrpro/data/Data.py b/src/mrpro/data/Data.py index b82a7cc53..cb0f3f3f5 100644 --- a/src/mrpro/data/Data.py +++ b/src/mrpro/data/Data.py @@ -2,9 +2,9 @@ import dataclasses from abc import ABC -from typing import Any import torch +from typing_extensions import Any from mrpro.data.MoveDataMixin import MoveDataMixin diff --git a/src/mrpro/data/DcfData.py b/src/mrpro/data/DcfData.py index baeb64db6..62726744d 100644 --- a/src/mrpro/data/DcfData.py +++ b/src/mrpro/data/DcfData.py @@ -4,9 +4,10 @@ import dataclasses from functools import reduce -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING import torch +from typing_extensions import Self from mrpro.algorithms.dcf.dcf_voronoi import dcf_1d, dcf_2d3d_voronoi from mrpro.data.KTrajectory import KTrajectory diff --git a/src/mrpro/data/EncodingLimits.py b/src/mrpro/data/EncodingLimits.py index 546cff883..daaa4bc6c 100644 --- a/src/mrpro/data/EncodingLimits.py +++ b/src/mrpro/data/EncodingLimits.py @@ -2,9 +2,9 @@ import dataclasses from dataclasses import dataclass -from typing import Self from ismrmrd.xsd.ismrmrdschema.ismrmrd import encodingLimitsType, limitType +from typing_extensions import Self @dataclass(slots=True) diff --git a/src/mrpro/data/IData.py b/src/mrpro/data/IData.py index 58ff29135..0c3cad043 100644 --- a/src/mrpro/data/IData.py +++ b/src/mrpro/data/IData.py @@ -3,7 +3,6 @@ import dataclasses 
from collections.abc import Generator, Sequence from pathlib import Path -from typing import Self import numpy as np import torch @@ -11,6 +10,7 @@ from pydicom import dcmread from pydicom.dataset import Dataset from pydicom.tag import TagType +from typing_extensions import Self from mrpro.data.Data import Data from mrpro.data.IHeader import IHeader diff --git a/src/mrpro/data/IHeader.py b/src/mrpro/data/IHeader.py index f500863b4..54eb98e62 100644 --- a/src/mrpro/data/IHeader.py +++ b/src/mrpro/data/IHeader.py @@ -3,12 +3,12 @@ import dataclasses from collections.abc import Sequence from dataclasses import dataclass -from typing import Self import numpy as np import torch from pydicom.dataset import Dataset from pydicom.tag import Tag, TagType +from typing_extensions import Self from mrpro.data.KHeader import KHeader from mrpro.data.MoveDataMixin import MoveDataMixin @@ -114,7 +114,7 @@ def deg_to_rad(deg: torch.Tensor | None) -> torch.Tensor | None: get_items_from_all_dicoms('PixelSpacing')[0][1], ) fov_z_mm = get_float_items_from_all_dicoms('SliceThickness')[0] - fov = SpatialDimension(fov_x_mm / 1000.0, fov_y_mm / 1000.0, fov_z_mm / 1000.0) + fov = SpatialDimension(fov_x_mm, fov_y_mm, fov_z_mm) / 1000 # convert to m # Get misc parameters misc = {} diff --git a/src/mrpro/data/KHeader.py b/src/mrpro/data/KHeader.py index 0cf5fcc87..dea12e28b 100644 --- a/src/mrpro/data/KHeader.py +++ b/src/mrpro/data/KHeader.py @@ -6,18 +6,19 @@ import datetime import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING import ismrmrd.xsd.ismrmrdschema.ismrmrd as ismrmrdschema import torch +from typing_extensions import Self from mrpro.data import enums -from mrpro.data.AcqInfo import AcqInfo, mm_to_m, ms_to_s +from mrpro.data.AcqInfo import AcqInfo from mrpro.data.EncodingLimits import EncodingLimits from mrpro.data.MoveDataMixin import MoveDataMixin from mrpro.data.SpatialDimension import SpatialDimension -from 
mrpro.data.TrajectoryDescription import TrajectoryDescription from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues +from mrpro.utils.unit_conversion import mm_to_m, ms_to_s if TYPE_CHECKING: # avoid circular imports by importing only when type checking @@ -39,9 +40,6 @@ class KHeader(MoveDataMixin): trajectory: KTrajectoryCalculator """Function to calculate the k-space trajectory.""" - b0: float - """Magnetic field strength [T].""" - encoding_limits: EncodingLimits """K-space encoding limits.""" @@ -60,12 +58,9 @@ class KHeader(MoveDataMixin): acq_info: AcqInfo """Information of the acquisitions (i.e. readout lines).""" - h1_freq: float + lamor_frequency_proton: float """Lamor frequency of hydrogen nuclei [Hz].""" - n_coils: int | None = None - """Number of receiver coils.""" - datetime: datetime.datetime | None = None """Date and time of acquisition.""" @@ -87,7 +82,7 @@ class KHeader(MoveDataMixin): echo_train_length: int = 1 """Number of echoes in a multi-echo acquisition.""" - seq_type: str = UNKNOWN + sequence_type: str = UNKNOWN """Type of sequence.""" model: str = UNKNOWN @@ -99,16 +94,13 @@ class KHeader(MoveDataMixin): protocol_name: str = UNKNOWN """Name of the acquisition protocol.""" - misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! - """Dictionary with miscellaneous parameters.""" - calibration_mode: enums.CalibrationMode = enums.CalibrationMode.OTHER """Mode of how calibration data is acquired. 
""" interleave_dim: enums.InterleavingDimension = enums.InterleavingDimension.OTHER """Interleaving dimension.""" - traj_type: enums.TrajectoryType = enums.TrajectoryType.OTHER + trajectory_type: enums.TrajectoryType = enums.TrajectoryType.OTHER """Type of trajectory.""" measurement_id: str = UNKNOWN @@ -117,8 +109,9 @@ class KHeader(MoveDataMixin): patient_name: str = UNKNOWN """Name of the patient.""" - trajectory_description: TrajectoryDescription = dataclasses.field(default_factory=TrajectoryDescription) - """Description of the trajectory.""" + _misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! + """Dictionary with miscellaneous parameters. These parameters are for information purposes only. Reconstruction + algorithms should not rely on them.""" @property def fa_degree(self) -> torch.Tensor | None: @@ -159,17 +152,14 @@ def from_ismrmrd( enc: ismrmrdschema.encodingType = header.encoding[encoding_number] # These are guaranteed to exist - parameters = {'h1_freq': header.experimentalConditions.H1resonanceFrequency_Hz, 'acq_info': acq_info} + parameters = { + 'lamor_frequency_proton': header.experimentalConditions.H1resonanceFrequency_Hz, + 'acq_info': acq_info, + } if defaults is not None: parameters.update(defaults) - if ( - header.acquisitionSystemInformation is not None - and header.acquisitionSystemInformation.receiverChannels is not None - ): - parameters['n_coils'] = header.acquisitionSystemInformation.receiverChannels - if header.sequenceParameters is not None: if header.sequenceParameters.TR: parameters['tr'] = ms_to_s(torch.as_tensor(header.sequenceParameters.TR)) @@ -183,14 +173,16 @@ def from_ismrmrd( parameters['echo_spacing'] = ms_to_s(torch.as_tensor(header.sequenceParameters.echo_spacing)) if header.sequenceParameters.sequence_type is not None: - parameters['seq_type'] = header.sequenceParameters.sequence_type + parameters['sequence_type'] = header.sequenceParameters.sequence_type if enc.reconSpace is not None: - 
parameters['recon_fov'] = SpatialDimension[float].from_xyz(enc.reconSpace.fieldOfView_mm, mm_to_m) + parameters['recon_fov'] = SpatialDimension[float].from_xyz(enc.reconSpace.fieldOfView_mm).apply_(mm_to_m) parameters['recon_matrix'] = SpatialDimension[int].from_xyz(enc.reconSpace.matrixSize) if enc.encodedSpace is not None: - parameters['encoding_fov'] = SpatialDimension[float].from_xyz(enc.encodedSpace.fieldOfView_mm, mm_to_m) + parameters['encoding_fov'] = ( + SpatialDimension[float].from_xyz(enc.encodedSpace.fieldOfView_mm).apply_(mm_to_m) + ) parameters['encoding_matrix'] = SpatialDimension[int].from_xyz(enc.encodedSpace.matrixSize) if enc.encodingLimits is not None: @@ -209,7 +201,7 @@ def from_ismrmrd( ) if enc.trajectory is not None: - parameters['traj_type'] = enums.TrajectoryType(enc.trajectory.value) + parameters['trajectory_type'] = enums.TrajectoryType(enc.trajectory.value) # Either use the series or study time if available if header.measurementInformation is not None and header.measurementInformation.seriesTime is not None: @@ -242,15 +234,8 @@ def from_ismrmrd( if header.acquisitionSystemInformation.systemModel is not None: parameters['model'] = header.acquisitionSystemInformation.systemModel - if header.acquisitionSystemInformation.systemFieldStrength_T is not None: - parameters['b0'] = header.acquisitionSystemInformation.systemFieldStrength_T - - # estimate b0 from h1_freq if not given - if 'b0' not in parameters: - parameters['b0'] = parameters['h1_freq'] / 4258e4 - # Dump everything into misc - parameters['misc'] = dataclasses.asdict(header) + parameters['_misc'] = dataclasses.asdict(header) if overwrite is not None: parameters.update(overwrite) diff --git a/src/mrpro/data/KNoise.py b/src/mrpro/data/KNoise.py index 4e7e8fb7b..ec606114d 100644 --- a/src/mrpro/data/KNoise.py +++ b/src/mrpro/data/KNoise.py @@ -3,11 +3,11 @@ import dataclasses from collections.abc import Callable from pathlib import Path -from typing import Self import ismrmrd import 
torch from einops import repeat +from typing_extensions import Self from mrpro.data.acq_filters import is_noise_acquisition from mrpro.data.MoveDataMixin import MoveDataMixin diff --git a/src/mrpro/data/KTrajectory.py b/src/mrpro/data/KTrajectory.py index f458d0d31..8cbe207cc 100644 --- a/src/mrpro/data/KTrajectory.py +++ b/src/mrpro/data/KTrajectory.py @@ -1,10 +1,10 @@ """KTrajectory dataclass.""" from dataclasses import dataclass -from typing import Self import numpy as np import torch +from typing_extensions import Self from mrpro.data.enums import TrajType from mrpro.data.MoveDataMixin import MoveDataMixin diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index 90f189e4c..2ac3c1a58 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -1,12 +1,13 @@ """MoveDataMixin.""" import dataclasses -from collections.abc import Iterator +from collections.abc import Callable, Iterator from copy import copy as shallowcopy from copy import deepcopy -from typing import Any, ClassVar, Protocol, Self, TypeAlias, overload, runtime_checkable +from typing import ClassVar, TypeAlias, cast import torch +from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable class InconsistentDeviceError(ValueError): # noqa: D101 @@ -21,6 +22,9 @@ class DataclassInstance(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +T = TypeVar('T') + + class MoveDataMixin: """Move dataclass fields to cpu/gpu and convert dtypes.""" @@ -150,7 +154,6 @@ def _to( copy: bool = False, memo: dict | None = None, ) -> Self: - new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self """Move data to device and convert dtype if necessary. This method is called by .to(), .cuda(), .cpu(), .double(), and so on. @@ -178,6 +181,8 @@ def _to( memo A dictionary to keep track of already converted objects to avoid multiple conversions. 
""" + new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + if memo is None: memo = {} @@ -218,26 +223,80 @@ def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin: memo=memo, ) - converted: Any - for name, data in new._items(): - if id(data) in memo: - object.__setattr__(new, name, memo[id(data)]) - continue + def _convert(data: T) -> T: + converted: Any # https://github.com/python/mypy/issues/10817 if isinstance(data, torch.Tensor): converted = _tensor_to(data) elif isinstance(data, MoveDataMixin): converted = _mixin_to(data) elif isinstance(data, torch.nn.Module): converted = _module_to(data) - elif copy: - converted = deepcopy(data) else: converted = data - memo[id(data)] = converted - # this works even if new is frozen - object.__setattr__(new, name, converted) + return cast(T, converted) + + # manual recursion allows us to do the copy only once + new.apply_(_convert, memo=memo, recurse=False) + return new + + def apply( + self: Self, + function: Callable[[Any], Any] | None = None, + *, + recurse: bool = True, + ) -> Self: + """Apply a function to all children. Returns a new object. + + Parameters + ---------- + function + The function to apply to all fields. None is interpreted as a no-op. + recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + new = self.clone().apply_(function, recurse=recurse) return new + def apply_( + self: Self, + function: Callable[[Any], Any] | None = None, + *, + memo: dict[int, Any] | None = None, + recurse: bool = True, + ) -> Self: + """Apply a function to all children in-place. + + Parameters + ---------- + function + The function to apply to all fields. None is interpreted as a no-op. + memo + A dictionary to keep track of objects that the function has already been applied to, + to avoid multiple applications. This is useful if the object has a circular reference. 
+ recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + applied: Any + + if memo is None: + memo = {} + + if function is None: + return self + + for name, data in self._items(): + if id(data) in memo: + # this works even if self is frozen + object.__setattr__(self, name, memo[id(data)]) + continue + if recurse and isinstance(data, MoveDataMixin): + applied = data.apply_(function, memo=memo) + else: + applied = function(data) + memo[id(data)] = applied + object.__setattr__(self, name, applied) + return self + def cuda( self, device: torch.device | str | int | None = None, diff --git a/src/mrpro/data/QData.py b/src/mrpro/data/QData.py index c55a55fa4..04f4ca1f4 100644 --- a/src/mrpro/data/QData.py +++ b/src/mrpro/data/QData.py @@ -2,12 +2,12 @@ import dataclasses from pathlib import Path -from typing import Self import numpy as np import torch from einops import repeat from pydicom import dcmread +from typing_extensions import Self from mrpro.data.Data import Data from mrpro.data.IHeader import IHeader diff --git a/src/mrpro/data/QHeader.py b/src/mrpro/data/QHeader.py index b3aec710c..e21453ca8 100644 --- a/src/mrpro/data/QHeader.py +++ b/src/mrpro/data/QHeader.py @@ -1,10 +1,10 @@ """MR quantitative data header (QHeader) dataclass.""" from dataclasses import dataclass -from typing import Self from pydicom.dataset import Dataset from pydicom.tag import Tag +from typing_extensions import Self from mrpro.data.IHeader import IHeader from mrpro.data.KHeader import KHeader @@ -60,5 +60,5 @@ def get_items(name: str): # Todo: move to utils and reuse logic in IHeader fov_x_mm = float(get_items('Rows')[0]) * get_items('PixelSpacing')[0][0] fov_y_mm = float(get_items('Columns')[0]) * get_items('PixelSpacing')[0][1] fov_z_mm = float(get_items('SliceThickness')[0]) - fov = SpatialDimension(fov_x_mm / 1000.0, fov_y_mm / 1000.0, fov_z_mm / 1000.0) + fov = SpatialDimension(fov_x_mm, fov_y_mm, fov_z_mm) / 1000 # convert to m return 
cls(fov=fov) diff --git a/src/mrpro/data/Rotation.py b/src/mrpro/data/Rotation.py index 706c298de..628d93c7e 100644 --- a/src/mrpro/data/Rotation.py +++ b/src/mrpro/data/Rotation.py @@ -1,8 +1,8 @@ """A pytorch implementation of scipy.spatial.transform.Rotation. -A container for Rotations, that can be created from quaternions, euler angles, rotation vectors, rotation matrices, -etc, can be applied to torch.Tensors and SpatialDimensions, multiplied, and can be converted to quaternions, -euler angles, etc. +A container for proper and improper Rotations, that can be created from quaternions, euler angles, rotation vectors, +rotation matrices, etc, can be applied to torch.Tensors and SpatialDimensions, multiplied, and can be converted +to quaternions, euler angles, etc. see also https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx """ @@ -42,18 +42,22 @@ from __future__ import annotations +import math import re import warnings -from collections.abc import Sequence -from typing import Literal, Self, overload +from collections.abc import Callable, Sequence +from typing import Literal, cast import numpy as np import torch +import torch.nn.functional as F # noqa: N812 +from einops import rearrange from scipy._lib._util import check_random_state -from scipy.spatial.transform import Rotation as Rotation_scipy +from typing_extensions import Self, Unpack, overload from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.utils.typing import IndexerType, NestedSequence +from mrpro.utils.typing import NestedSequence, TorchIndexerType +from mrpro.utils.vmf import sample_vmf AXIS_ORDER = 'zyx' # This can be modified QUAT_AXIS_ORDER = AXIS_ORDER + 'w' # Do not modify @@ -170,6 +174,31 @@ def _quaternion_to_matrix(quaternion: torch.Tensor) -> torch.Tensor: return matrix +def _quaternion_to_axis_angle(quaternion: torch.Tensor, degrees: bool = False) -> tuple[torch.Tensor, torch.Tensor]: + """Convert quaternion to rotation axis and angle. 
+ + Parameters + ---------- + quaternion + The batched quaternions, shape (..., 4) + degrees + If True, the angle is returned in degrees, otherwise in radians. + + Returns + ------- + axis + The rotation axis, shape (..., 3) + angle + The rotation angle, shape (...) + """ + quaternion = _canonical_quaternion(quaternion) + angle = 2 * torch.atan2(torch.linalg.vector_norm(quaternion[..., :3], dim=-1), quaternion[..., 3]) + axis = quaternion[..., :3] / torch.linalg.vector_norm(quaternion[..., :3], dim=-1, keepdim=True) + if degrees: + angle = torch.rad2deg(angle) + return axis, angle + + def _quaternion_to_euler(quaternion: torch.Tensor, seq: str, extrinsic: bool): """Convert quaternion to euler angles. @@ -241,6 +270,114 @@ def _quaternion_to_euler(quaternion: torch.Tensor, seq: str, extrinsic: bool): return angles +def _align_vectors( + a: torch.Tensor, + b: torch.Tensor, + weights: torch.Tensor, + return_sensitivity: bool = False, + allow_improper: bool = False, +): + """Estimate a rotation to optimally align two sets of vectors.""" + n_vecs = a.shape[0] + if a.shape != b.shape: + raise ValueError(f'Expected inputs to have same shapes, got {a.shape} and {b.shape}') + if a.shape[-1] != 3: + raise ValueError(f'Expected inputs to have shape (..., 3), got {a.shape} and {b.shape}') + if weights.shape != (n_vecs,) or (weights < 0).any(): + raise ValueError(f'Invalid weights: expected shape ({n_vecs},) with non-negative values') + if (a.norm(dim=-1) < 1e-6).any() or (b.norm(dim=-1) < 1e-6).any(): + raise ValueError('Cannot align zero length primary vectors') + dtype = torch.result_type(a, b) + # we require double precision for the calculations to match scipy results + weights = weights.double() + a = a.double() + b = b.double() + + inf_mask = torch.isinf(weights) + if inf_mask.sum() > 1: + raise ValueError('Only one infinite weight is allowed') + + if inf_mask.any() or n_vecs == 1: + # special case for one vector pair or one infinite weight + + if return_sensitivity: + 
raise ValueError('Cannot return sensitivity matrix with an infinite weight or one vector pair') + + a_primary, b_primary = (a[0], b[0]) if n_vecs == 1 else (a[inf_mask][0], b[inf_mask][0]) + a_primary, b_primary = F.normalize(a_primary, dim=0), F.normalize(b_primary, dim=0) + cross = torch.linalg.cross(b_primary, a_primary, dim=0) + angle = torch.atan2(torch.norm(cross), torch.dot(a_primary, b_primary)) + rot_primary = _axisangle_to_matrix(cross, angle) + + if n_vecs == 1: + return rot_primary.to(dtype), torch.tensor(0.0, device=a.device, dtype=dtype) + + a_secondary, b_secondary = a[~inf_mask], b[~inf_mask] + sec_w = weights[~inf_mask] + rot_sec_b = (rot_primary @ b_secondary.T).T + sin_term = torch.einsum('ij,j->i', torch.linalg.cross(rot_sec_b, a_secondary, dim=1), a_primary) + cos_term = torch.einsum('ij,ij->i', rot_sec_b, a_secondary) - torch.einsum( + 'ij,j->i', rot_sec_b, a_primary + ) * torch.einsum('ij,j->i', a_secondary, a_primary) + + phi = torch.atan2((sec_w * sin_term).sum(), (sec_w * cos_term).sum()) + rot_secondary = _axisangle_to_matrix(a_primary, phi) + rot_optimal = rot_secondary @ rot_primary + rssd_w = weights.clone() + rssd_w[inf_mask] = 0 + est_a = (rot_optimal @ b.T).T + rssd = torch.sqrt(torch.sum(rssd_w * torch.sum((a - est_a) ** 2, dim=1))) + return rot_optimal.to(dtype), rssd.to(dtype) + + corr_mat = torch.einsum('i j, i k, i -> j k', a, b, weights) + u, s, vt = cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.linalg.svd(corr_mat)) + if s[1] + s[2] < 1e-16 * s[0]: + warnings.warn('Optimal rotation is not uniquely or poorly defined for the given sets of vectors.', stacklevel=2) + + if (u @ vt).det() < 0 and not allow_improper: + u[:, -1] *= -1 + + rot_optimal = (u @ vt).to(dtype) + rssd = ((weights * (b**2 + a**2).sum(dim=1)).sum() - 2 * s.sum()).clamp_min(0.0).sqrt().to(dtype) + + if return_sensitivity: + zeta = (s[0] + s[1]) * (s[1] + s[2]) * (s[2] + s[0]) + kappa = s[0] * s[1] + s[1] * s[2] + s[2] * s[0] + sensitivity = ( + 
weights.mean() / zeta * (kappa * torch.eye(3, device=a.device, dtype=torch.float64) + corr_mat @ corr_mat.T) + ).to(dtype) + return rot_optimal, rssd, sensitivity + + return rot_optimal, rssd + + +def _axisangle_to_matrix(axis: torch.Tensor, angle: torch.Tensor) -> torch.Tensor: + """Compute a rotation matrix using Rodrigues' rotation formula.""" + axis = F.normalize(axis, dim=-1, eps=1e-6) + cos, sin = torch.cos(angle), torch.sin(angle) + t = 1 - cos + q, r, s = axis.unbind(-1) + matrix = rearrange( + torch.stack( + [ + t * q * q + cos, + t * q * r - s * sin, + t * q * s + r * sin, + t * q * r + s * sin, + t * r * r + cos, + t * r * s - q * sin, + t * q * s - r * sin, + t * r * s + q * sin, + t * s * s + cos, + ], + dim=-1, + ), + '... (row col) -> ... row col', + row=3, + ) + return matrix + + class Rotation(torch.nn.Module): """A container for Rotations. @@ -251,12 +388,20 @@ class Rotation(torch.nn.Module): Differences compared to scipy.spatial.transform.Rotation: - torch.nn.Module based, the quaternions are a Parameter - - .apply is replaced by call/forward. - not all features are implemented. Notably, mrp, davenport, and reduce are missing. - arbitrary number of batching dimensions + - support for improper rotations (rotoinversion), i.e., rotations with an coordinate inversion + or a reflection about a plane perpendicular to the rotation axis. """ - def __init__(self, quaternions: torch.Tensor | NestedSequence[float], normalize: bool = True, copy: bool = True): + def __init__( + self, + quaternions: torch.Tensor | NestedSequence[float], + normalize: bool = True, + copy: bool = True, + inversion: torch.Tensor | NestedSequence[bool] | bool = False, + reflection: torch.Tensor | NestedSequence[bool] | bool = False, + ) -> None: """Initialize a new Rotation. Instead of calling this method, also consider the different ``from_*`` class methods to construct a Rotation. 
@@ -266,12 +411,23 @@ def __init__(self, quaternions: torch.Tensor | NestedSequence[float], normalize: quaternions Rotatation quaternions. If these requires_grad, the resulting Rotation will require gradients normalize - If the quaternions should be normalized. Only disable if you are sure the quaternions are already normalized + If the quaternions should be normalized. Only disable if you are sure the quaternions are already + normalized. + Will keep a possible negative w to represent improper rotations. copy Always ensure that a copy of the quaternions is created. If both normalize and copy are False, the quaternions Parameter of this instance will be a view if the quaternions passed in. + inversion + If the rotation should contain an inversion of the coordinate system, i.e. a reflection of all three axes, + resulting in a rotoinversion (improper rotation). + If a boolean tensor is given, it should broadcast with the quaternions. + reflection + If the rotation should contain a reflection about a plane perpendicular to the rotation axis. + This will result in a rotoflexion (improper rotation). + If a boolean tensor is given, it should broadcast with the quaternions. 
""" super().__init__() + quaternions_ = torch.as_tensor(quaternions) if torch.is_complex(quaternions_): raise ValueError('quaternions should be real numbers') @@ -281,11 +437,27 @@ def __init__(self, quaternions: torch.Tensor | NestedSequence[float], normalize: if quaternions_.shape[-1] != 4: raise ValueError('Expected `quaternions` to have shape (..., 4), ' f'got {quaternions_.shape}.') + reflection_ = torch.as_tensor(reflection) + inversion_ = torch.as_tensor(inversion) + if reflection_.any(): + axis, angle = _quaternion_to_axis_angle(quaternions_) + angle = (angle + torch.pi * reflection_.float()).unsqueeze(-1) + is_improper = inversion_ ^ reflection_ + quaternions_ = torch.cat((torch.sin(angle / 2) * axis, torch.cos(angle / 2)), -1) + elif inversion_.any(): + is_improper = inversion_ + else: + is_improper = torch.zeros_like(quaternions_[..., 0], dtype=torch.bool) + + batchsize = torch.broadcast_shapes(quaternions_.shape[:-1], is_improper.shape) + is_improper = is_improper.expand(batchsize) + # If a single quaternion is given, convert it to a 2D 1 x 4 matrix but # set self._single to True so that we can return appropriate objects # in the `to_...` methods if quaternions_.shape == (4,): quaternions_ = quaternions_[None, :] + is_improper = is_improper[None] self._single = True else: self._single = False @@ -296,19 +468,56 @@ def __init__(self, quaternions: torch.Tensor | NestedSequence[float], normalize: raise ValueError('Found zero norm quaternion in `quaternions`.') quaternions_ = quaternions_ / norms elif copy: + # no need to clone if we are normalizing quaternions_ = quaternions_.clone() + if copy: + is_improper = is_improper.clone() + + if is_improper.requires_grad: + warnings.warn('Rotation is not differentiable in the improper parameter.', stacklevel=2) + self._quaternions = torch.nn.Parameter(quaternions_, quaternions_.requires_grad) + self._is_improper = torch.nn.Parameter(is_improper, False) @property def single(self) -> bool: """Returns true if this a 
single rotation.""" return self._single + @property + def is_improper(self) -> torch.Tensor: + """Returns a true boolean tensor if the rotation is improper.""" + return self._is_improper + + @is_improper.setter + def is_improper(self, improper: torch.Tensor | NestedSequence[bool] | bool) -> None: + """Set the improper parameter.""" + self._is_improper[:] = torch.as_tensor(improper, dtype=torch.bool, device=self._is_improper.device) + + @property + def det(self) -> torch.Tensor: + """Returns the determinant of the rotation matrix. + + Will be 1. for proper rotations and -1. for improper rotations. + """ + return self._is_improper.float() * -2 + 1 + @classmethod - def from_quat(cls, quaternions: torch.Tensor | NestedSequence[float]) -> Self: + def from_quat( + cls, + quaternions: torch.Tensor | NestedSequence[float], + inversion: torch.Tensor | NestedSequence[bool] | bool = False, + reflection: torch.Tensor | NestedSequence[bool] | bool = False, + ) -> Self: """Initialize from quaternions. 3D rotations can be represented using unit-norm quaternions [QUAa]_. + As an extension to the standard, this class also supports improper rotations, + i.e. rotations with reflection with respect to the plane perpendicular to the rotation axis + or inversion of the coordinate system. + + Note: If inversion != reflection, the rotation will be improper and save as a rotation followed by an inversion. + containing an inversion of the coordinate system. Parameters ---------- @@ -317,6 +526,12 @@ def from_quat(cls, quaternions: torch.Tensor | NestedSequence[float]) -> Self: Each row is a (possibly non-unit norm) quaternion representing an active rotation, in scalar-last (x, y, z, w) format. Each quaternion will be normalized to unit norm. + inversion + if the rotation should contain an inversion of the coordinate system, i.e. a reflection + of all three axes. If a boolean tensor is given, it should broadcast with the quaternions. 
+ reflection + if the rotation should contain a reflection about a plane perpendicular to the rotation axis. + Returns ------- @@ -327,22 +542,26 @@ def from_quat(cls, quaternions: torch.Tensor | NestedSequence[float]) -> Self: ---------- .. [QUAa] Quaternions and spatial rotation https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation """ - if not isinstance(quaternions, torch.Tensor): - quaternions = torch.as_tensor(quaternions) - return cls(quaternions, normalize=True) + return cls(quaternions, normalize=True, copy=True, inversion=inversion, reflection=reflection) @classmethod - def from_matrix(cls, matrix: torch.Tensor | NestedSequence[float]) -> Self: + def from_matrix(cls, matrix: torch.Tensor | NestedSequence[float], allow_improper: bool = True) -> Self: """Initialize from rotation matrix. Rotations in 3 dimensions can be represented with 3 x 3 proper orthogonal matrices [ROTa]_. If the input is not proper orthogonal, an approximation is created using the method described in [MAR2008]_. + If the input matrix has a negative determinant, the rotation is considered + as improper, i.e. containing a reflection. The resulting rotation + will include this reflection [ROTb]_. Parameters ---------- matrix A single matrix or a stack of matrices, shape (..., 3, 3) + allow_improper + If true, the rotation is considered as improper if the determinant of the matrix is negative. + If false, an ValueError is raised if the determinant is negative. Returns ------- @@ -353,24 +572,92 @@ def from_matrix(cls, matrix: torch.Tensor | NestedSequence[float]) -> Self: References ---------- .. [ROTa] Rotation matrix https://en.wikipedia.org/wiki/Rotation_matrix#In_three_dimensions + .. [ROTb] Rotation matrix https://en.wikipedia.org/wiki/Improper_rotation .. [MAR2008] Landis Markley F (2008) Unit Quaternion from Rotation Matrix, Journal of guidance, control, and dynamics 31(2),440-442. 
""" - if not isinstance(matrix, torch.Tensor): - matrix = torch.as_tensor(matrix) - if matrix.shape[-2:] != (3, 3): - raise ValueError(f'Expected `matrix` to have shape (..., 3, 3), got {matrix.shape}') - if torch.is_complex(matrix): + matrix_ = torch.as_tensor(matrix) + if matrix_.shape[-2:] != (3, 3): + raise ValueError(f'Expected `matrix` to have shape (..., 3, 3), got {matrix_.shape}') + if torch.is_complex(matrix_): raise ValueError('matrix should be real, not complex.') - if not torch.is_floating_point(matrix): + if not torch.is_floating_point(matrix_): # integer or boolean dtypes - matrix = matrix.float() - quaternions = _matrix_to_quaternion(matrix) + matrix_ = matrix_.float() + + det = torch.linalg.det(matrix_) + improper = det < 0 + if improper.any(): + if not allow_improper: + raise ValueError( + 'Found negative determinant in `matrix`. ' + 'This would result in an improper rotation, but allow_improper is False.' + ) + matrix_ = matrix_ * det.unsqueeze(-1).unsqueeze(-1).sign() + + quaternions = _matrix_to_quaternion(matrix_) + + return cls(quaternions, normalize=True, copy=False, inversion=improper, reflection=False) + + @classmethod + def from_directions( + cls, *basis: Unpack[tuple[SpatialDimension, SpatialDimension, SpatialDimension]], allow_improper: bool = True + ): + """Initialize from basis vectors as SpatialDimensions. + + Parameters + ---------- + *basis + 3 Basis vectors of the new coordinate system, i.e. the columns of the rotation matrix + allow_improper + If true, the rotation is considered as improper if the determinant of the matrix is negative + and the sign will be preserved. If false, a ValueError is raised if the determinant is negative. + + + Returns + ------- + rotation + Object containing the rotations represented by the basis vectors. 
+ """ + b1, b2, b3 = (torch.stack([torch.as_tensor(getattr(v_, axis)) for axis in AXIS_ORDER], -1) for v_ in basis) + matrix = torch.stack((b1, b2, b3), -1) + det = torch.linalg.det(matrix) + if not allow_improper and (det < 0).any(): + raise ValueError('The given basis vectors do not form a proper rotation matrix.') + if ((1 - det.abs()) > 0.1).any(): + raise ValueError('The given basis vectors do not form a rotation matrix.') + + return cls.from_matrix(matrix, allow_improper=allow_improper) + + def as_directions( + self, + ) -> tuple[SpatialDimension[torch.Tensor], SpatialDimension[torch.Tensor], SpatialDimension[torch.Tensor]]: + """Represent as the basis vectors of the new coordinate system as SpatialDimensions. - return cls(quaternions, normalize=True, copy=False) + Returns the three basis vectors of the new coordinate system after rotation, + i.e. the columns of the rotation matrix, as SpatialDimensions. + + Returns + ------- + basis + The basis vectors of the new coordinate system. + """ + matrix = self.as_matrix() + ret = ( + SpatialDimension(**dict(zip(AXIS_ORDER, matrix[..., 0].unbind(-1), strict=True))), + SpatialDimension(**dict(zip(AXIS_ORDER, matrix[..., 1].unbind(-1), strict=True))), + SpatialDimension(**dict(zip(AXIS_ORDER, matrix[..., 2].unbind(-1), strict=True))), + ) + return ret @classmethod - def from_rotvec(cls, rotvec: torch.Tensor | NestedSequence[float], degrees: bool = False) -> Self: + def from_rotvec( + cls, + rotvec: torch.Tensor | NestedSequence[float], + degrees: bool = False, + reflection: torch.Tensor | NestedSequence[bool] | bool = False, + inversion: torch.Tensor | NestedSequence[bool] | bool = False, + ) -> Self: """Initialize from rotation vector. A rotation vector is a 3 dimensional vector which is co-directional to the @@ -383,28 +670,56 @@ def from_rotvec(cls, rotvec: torch.Tensor | NestedSequence[float], degrees: bool degrees If True, then the given angles are assumed to be in degrees, otherwise radians. 
+ reflection + If True, the resulting transformation will contain a reflection + about a plane perpendicular to the rotation axis, resulting in a rotoflection + (improper rotation). + inversion + If True, the resulting transformation will contain an inversion of the coordinate system, + resulting in a rotoinversion (improper rotation). + + Returns + ------- + rotation + Object containing the rotations represented by the rotation vectors. """ - if not isinstance(rotvec, torch.Tensor): - rotvec = torch.as_tensor(rotvec) - if torch.is_complex(rotvec): + rotvec_ = torch.as_tensor(rotvec) + reflection_ = torch.as_tensor(reflection) + inversion_ = torch.as_tensor(inversion) + if rotvec_.is_complex(): raise ValueError('rotvec should be real numbers') - if not torch.is_floating_point(rotvec): + if not rotvec_.is_floating_point(): # integer or boolean dtypes - rotvec = rotvec.float() + rotvec_ = rotvec_.float() if degrees: - rotvec = torch.deg2rad(rotvec) + rotvec_ = torch.deg2rad(rotvec_) - if rotvec.shape[-1] != 3: - raise ValueError(f'Expected `rot_vec` to have shape (..., 3), got {rotvec.shape}') + if rotvec_.shape[-1] != 3: + raise ValueError(f'Expected `rot_vec` to have shape (..., 3), got {rotvec_.shape}') - angles = torch.linalg.vector_norm(rotvec, dim=-1, keepdim=True) + angles = torch.linalg.vector_norm(rotvec_, dim=-1, keepdim=True) scales = torch.special.sinc(angles / (2 * torch.pi)) / 2 - quaternions = torch.cat((scales * rotvec, torch.cos(angles / 2)), -1) - return cls(quaternions, normalize=False, copy=False) + quaternions = torch.cat((scales * rotvec_, torch.cos(angles / 2)), -1) + if reflection_.any(): + # we can do it here and avoid the extra of converting to quaternions, + # back to axis-angle and then to quaternions. 
+ inversion_ = reflection_ ^ inversion_ + scales = torch.cos(0.5 * angles) / angles + reflected_quaternions = torch.cat((scales * rotvec_, -torch.sin(angles / 2)), -1) + quaternions = torch.where(reflection_, reflected_quaternions, quaternions) + + return cls(quaternions, normalize=False, copy=False, inversion=inversion_, reflection=False) @classmethod - def from_euler(cls, seq: str, angles: torch.Tensor | NestedSequence[float] | float, degrees: bool = False) -> Self: + def from_euler( + cls, + seq: str, + angles: torch.Tensor | NestedSequence[float] | float, + degrees: bool = False, + inversion: torch.Tensor | NestedSequence[bool] | bool = False, + reflection: torch.Tensor | NestedSequence[bool] | bool = False, + ) -> Self: """Initialize from Euler angles. Rotations in 3-D can be represented by a sequence of 3 @@ -430,6 +745,13 @@ def from_euler(cls, seq: str, angles: torch.Tensor | NestedSequence[float] | flo degrees If True, then the given angles are assumed to be in degrees. Otherwise they are assumed to be in radians + inversion + If True, the resulting transformation will contain an inversion of the coordinate system, + resulting in a rotoinversion (improper rotation). + reflection + If True, the resulting transformation will contain a reflection + about a plane perpendicular to the rotation axis, resulting in an + improper rotation. 
Returns ------- @@ -476,9 +798,9 @@ def from_euler(cls, seq: str, angles: torch.Tensor | NestedSequence[float] | flo quaternions = _compose_quaternions(_make_elementary_quat(axis, angle), quaternions) if is_single: - return cls(quaternions[0], normalize=False, copy=False) + return cls(quaternions[0], normalize=False, copy=False, inversion=inversion, reflection=reflection) else: - return cls(quaternions, normalize=False, copy=False) + return cls(quaternions, normalize=False, copy=False, inversion=inversion, reflection=reflection) @classmethod def from_davenport(cls, axes: torch.Tensor, order: str, angles: torch.Tensor, degrees: bool = False): @@ -490,7 +812,21 @@ def from_mrp(cls, mrp: torch.Tensor) -> Self: """Not implemented.""" raise NotImplementedError - def as_quat(self, canonical: bool = False) -> torch.Tensor: + @overload + def as_quat( + self, canonical: bool = ..., *, improper: Literal['warn'] | Literal['ignore'] = 'warn' + ) -> torch.Tensor: ... + @overload + def as_quat( + self, canonical: bool = ..., *, improper: Literal['reflection'] | Literal['inversion'] + ) -> tuple[torch.Tensor, torch.Tensor]: ... + + def as_quat( + self, + canonical: bool = False, + *, + improper: Literal['reflection'] | Literal['inversion'] | Literal['ignore'] | Literal['warn'] = 'warn', + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Represent as quaternions. Active rotations in 3 dimensions can be represented using unit norm @@ -507,28 +843,71 @@ def as_quat(self, canonical: bool = False) -> torch.Tensor: chosen from {q, -q} such that the w term is positive. If the w term is 0, then the quaternion is chosen such that the first nonzero term of the x, y, and z terms is positive. + improper + How to handle improper rotations. If 'warn', a warning is raised if + the rotation is improper. If 'ignore', the reflection information is + discarded. 
If 'reflection' or 'inversion', additional information is + returned in the form of a boolean tensor indicating if the rotation + is improper. + If 'reflection', the boolean tensor indicates if the rotation contains + a reflection about a plane perpendicular to the rotation axis. + Note that this requires additional computation. + If 'inversion', the boolean tensor indicates if the rotation contains + an inversion of the coordinate system. + The quaternion is adjusted to represent the rotation to be performed + before the reflection or inversion. Returns ------- quaternions shape (..., 4,), depends on shape of inputs used for initialization. + (optional) reflection (if improper is 'reflection') or inversion (if improper is 'inversion') + boolean tensor of shape (...,), indicating if the rotation is improper + and if a reflection or inversion should be performed after the rotation. References ---------- .. [QUAb] Quaternions https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation """ quaternions: torch.Tensor = self._quaternions - if canonical: - quaternions = _canonical_quaternion(quaternions) + is_improper: torch.Tensor = self._is_improper + + if improper == 'warn': + if is_improper.any(): + warnings.warn( + 'Rotation contains improper rotations. Set `improper="reflection"` or `improper="inversion"` ' + 'to get reflection or inversion information.', + stacklevel=2, + ) + elif improper == 'ignore' or improper == 'inversion': + ... 
+ elif improper == 'reflection': + axis, angle = _quaternion_to_axis_angle(quaternions) + angle = (angle + torch.pi * is_improper.float()).unsqueeze(-1) + quaternions = torch.cat((torch.sin(angle / 2) * axis, torch.cos(angle / 2)), -1) + else: + raise ValueError(f'Invalid improper value: {improper}') + if self.single: quaternions = quaternions[0] - return quaternions + is_improper = is_improper[0] + + if canonical: + quaternions = _canonical_quaternion(quaternions) + else: + quaternions = quaternions.clone() + + if improper == 'reflection' or improper == 'inversion': + return quaternions, is_improper + else: + return quaternions def as_matrix(self) -> torch.Tensor: """Represent as rotation matrix. 3D rotations can be represented using rotation matrices, which - are 3 x 3 real orthogonal matrices with determinant equal to +1 [ROTb]_. + are 3 x 3 real orthogonal matrices with determinant equal to +1 [ROTb]_ + for proper rotations and -1 for improper rotations. Returns ------- @@ -541,12 +920,28 @@ def as_matrix(self) -> torch.Tensor: """ quaternions = self._quaternions matrix = _quaternion_to_matrix(quaternions) + if self._is_improper.any(): + matrix = matrix * self.det.unsqueeze(-1).unsqueeze(-1) + if self._single: return matrix[0] else: return matrix - def as_rotvec(self, degrees: bool = False) -> torch.Tensor: + @overload + def as_rotvec( + self, degrees: bool = ..., *, improper: Literal['ignore'] | Literal['warn'] = 'warn' + ) -> torch.Tensor: ... + @overload + def as_rotvec( + self, degrees: bool = ..., *, improper: Literal['reflection'] | Literal['inversion'] + ) -> tuple[torch.Tensor, torch.Tensor]: ... + + def as_rotvec( + self, + degrees: bool = False, + improper: Literal['reflection'] | Literal['inversion'] | Literal['ignore'] | Literal['warn'] = 'warn', + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Represent as rotation vectors. 
A rotation vector is a 3 dimensional vector which is co-directional to @@ -556,33 +951,69 @@ def as_rotvec(self, degrees: bool = False) -> torch.Tensor: ---------- degrees Returned magnitudes are in degrees if this flag is True, else they are in radians + improper + How to handle improper rotations. If 'warn', a warning is raised if + the rotation is improper. If 'ignore', the reflection information is + discarded. If 'reflection' or 'inversion', additional information is + returned in the form of a boolean tensor indicating if the rotation + is improper. + If 'reflection', the boolean tensor indicates if the rotation contains + a reflection about a plane perpendicular to the rotation axis. + If 'inversion', the boolean tensor indicates if the rotation contains + an inversion of the coordinate system. + The quaternion is adjusted to represent the rotation to be performed + before the reflection or inversion. Returns ------- rotvec Shape (..., 3), depends on shape of inputs used for initialization. + (optional) reflection (if improper is 'reflection') or inversion (if improper is 'inversion') + boolean tensor of shape (...,), indicating if the rotation is improper + and if a reflection or inversion should be performed after the rotation. + References ---------- .. 
[ROTc] Rotation vector https://en.wikipedia.org/wiki/Axis%E2%80%93angle_representation#Rotation_vector """ - quaternions: torch.Tensor = self._quaternions - quaternions = _canonical_quaternion(quaternions) # w > 0 ensures that 0 <= angle <= pi - + if improper == 'reflection' or improper == 'inversion': + quaternions, is_improper = self.as_quat(canonical=True, improper=improper) + else: + quaternions, is_improper = self.as_quat(canonical=True, improper=improper), None angles = 2 * torch.atan2(torch.linalg.vector_norm(quaternions[..., :3], dim=-1), quaternions[..., 3]) scales = 2 / (torch.special.sinc(angles / (2 * torch.pi))) - rotvec = scales[..., None] * quaternions[..., :3] - if degrees: rotvec = torch.rad2deg(rotvec) + if is_improper is not None: + return rotvec, is_improper + else: + return rotvec - if self._single: - rotvec = rotvec[0] - - return rotvec - - def as_euler(self, seq: str, degrees: bool = False) -> torch.Tensor: + @overload + def as_euler( + self, + seq: str, + degrees: bool = ..., + *, + improper: Literal['ignore'] | Literal['warn'] = 'warn', + ) -> torch.Tensor: ... + @overload + def as_euler( + self, + seq: str, + degrees: bool = ..., + *, + improper: Literal['reflection'] | Literal['inversion'], + ) -> tuple[torch.Tensor, torch.Tensor]: ... + def as_euler( + self, + seq: str, + degrees: bool = False, + *, + improper: Literal['reflection'] | Literal['inversion'] | Literal['ignore'] | Literal['warn'] = 'warn', + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Represent as Euler angles. Any orientation can be expressed as a composition of 3 elementary @@ -609,6 +1040,18 @@ def as_euler(self, seq: str, degrees: bool = False) -> torch.Tensor: degrees Returned angles are in degrees if this flag is True, else they are in radians + improper + How to handle improper rotations. If 'warn', a warning is raised if + the rotation is improper. If 'ignore', the reflection information is + discarded. 
If 'reflection' or 'inversion', additional information is + returned in the form of a boolean tensor indicating if the rotation + is improper. + If 'reflection', the boolean tensor indicates if the rotation contains + a reflection about a plane perpendicular to the rotation axis. + If 'inversion', the boolean tensor indicates if the rotation contains + an inversion of the coordinate system. + The quaternion is adjusted to represent the rotation to be performed + before the reflection or inversion. Returns ------- @@ -642,8 +1085,11 @@ def as_euler(self, seq: str, degrees: bool = False) -> torch.Tensor: raise ValueError('Expected consecutive axes to be different, ' f'got {seq}') seq = seq.lower() + if improper == 'reflection' or improper == 'inversion': + quat, is_improper = self.as_quat(improper=improper) + else: + quat, is_improper = self.as_quat(improper=improper), None - quat = self.as_quat() if quat.ndim == 1: quat = quat[None, :] @@ -651,7 +1097,12 @@ def as_euler(self, seq: str, degrees: bool = False) -> torch.Tensor: if degrees: angles = torch.rad2deg(angles) - return angles[0] if self._single else angles + angles_ = angles[0] if self._single else angles + + if is_improper is not None: + return angles_, is_improper + else: + return angles_ def as_davenport(self, axes: torch.Tensor, order: str, degrees: bool = False) -> torch.Tensor: """Not implemented.""" @@ -679,7 +1130,60 @@ def concatenate(cls, rotations: Sequence[Rotation]) -> Self: raise TypeError('input must contain Rotation objects only') quats = torch.cat([torch.atleast_2d(x.as_quat()) for x in rotations]) - return cls(quats, normalize=False) + inversions = torch.cat([torch.atleast_1d(x._is_improper) for x in rotations]) + return cls(quats, normalize=False, copy=False, inversion=inversions, reflection=False) + + @overload + def apply(self, fn: NestedSequence[float] | torch.Tensor, inverse: bool) -> torch.Tensor: ... 
+ + @overload + def apply( + self, fn: SpatialDimension[torch.Tensor] | SpatialDimension[float], inverse: bool + ) -> SpatialDimension[torch.Tensor]: ... + + @overload + def apply(self, fn: Callable[[torch.nn.Module], None]) -> Self: ... + + def apply( + self, + fn: NestedSequence[float] + | torch.Tensor + | SpatialDimension[torch.Tensor] + | SpatialDimension[float] + | Callable[[torch.nn.Module], None], + inverse: bool = False, + ) -> torch.Tensor | SpatialDimension[torch.Tensor] | Self: + """Either apply a function to the Rotation module or apply the rotation to a vector. + + This is a hybrid method that matches the signature of both `torch.nn.Module.apply` and + `scipy.spatial.transform.Rotation.apply`. + If a callable is passed, it is assumed to be a function that will be applied to the Rotation module. + For applying the rotation to a vector, consider using `Rotation(vector)` instead of `Rotation.apply(vector)`. + """ + if callable(fn): + # torch.nn.Module.apply + return super().apply(fn) + else: + # scipy.spatial.transform.Rotation.apply + warnings.warn('Consider using Rotation(vector) instead of Rotation.apply(vector).', stacklevel=2) + return self(fn, inverse) + + @overload + def __call__(self, vectors: NestedSequence[float] | torch.Tensor, inverse: bool = False) -> torch.Tensor: ... + + @overload + def __call__( + self, vectors: SpatialDimension[torch.Tensor] | SpatialDimension[float], inverse: bool = False + ) -> SpatialDimension[torch.Tensor]: ... 
+ + def __call__( + self, + vectors: NestedSequence[float] | torch.Tensor | SpatialDimension[torch.Tensor] | SpatialDimension[float], + inverse: bool = False, + ) -> torch.Tensor | SpatialDimension[torch.Tensor]: + """Apply this rotation to a set of vectors.""" + # Only for type hinting + return super().__call__(vectors, inverse) def forward( self, @@ -767,6 +1271,7 @@ def random( cls, num: int | Sequence[int] | None = None, random_state: int | np.random.RandomState | np.random.Generator | None = None, + improper: bool | Literal['random'] = False, ): """Generate uniformly distributed rotations. @@ -782,6 +1287,9 @@ def random( seeded with `random_state`. If `random_state` is already a ``Generator`` or ``RandomState`` instance then that instance is used. + improper + if True, only improper rotations are generated. If False, only proper rotations are generated. + if "random", then a random mix of proper and improper rotations are generated. Returns ------- @@ -797,8 +1305,58 @@ def random( random_sample = torch.as_tensor(generator.normal(size=(num, 4)), dtype=torch.float32) else: random_sample = torch.as_tensor(generator.normal(size=(*num, 4)), dtype=torch.float32) + if improper == 'random': + inversion: torch.Tensor | bool = torch.as_tensor( + generator.choice([True, False], size=random_sample.shape[:-1]), dtype=torch.bool + ) + elif isinstance(improper, bool): + inversion = improper + else: + raise ValueError('improper should be a boolean or "random"') + return cls(random_sample, inversion=inversion, reflection=False, normalize=True, copy=False) + + @classmethod + def random_vmf( + cls, + num: int | None = None, + mean_axis: torch.Tensor | None = None, + kappa: float = 0.0, + sigma: float = math.inf, + ): + """ + Randomly sample rotations from a von Mises-Fisher distribution. + + Generate rotations from a von Mises-Fisher distribution with a given mean axis and concentration parameter + and a 2pi-wrapped Gaussian distribution for the rotation angle. 
+ + Parameters + ---------- + mean_axis + shape (..., 3,), the mean axis of the von Mises-Fisher distribution. + kappa + The concentration parameter of the von Mises-Fisher distribution. + small kappa results in a uniform distribution, large kappa results in a peak around the mean axis. + similar to the inverse of the variance of a Gaussian distribution. + sigma + Standard deviation (radians) of the 2pi-wrapped Gaussian distribution used to sample the rotation angle. + Use `math.inf` if a uniform distribution is desired. + num + number of samples to generate. If None, a single rotation is generated. - return cls(random_sample) + Returns + ------- + random_rotation + a stack of `(num, ...)` rotations. + + """ + n = 1 if num is None else num + mu = torch.tensor((1.0, 0.0, 0.0)) if mean_axis is None else torch.as_tensor(mean_axis) + rot_axes = sample_vmf(mu=mu, kappa=kappa, n_samples=n) + if sigma == math.inf: + rot_angle = torch.rand(n, *mu.shape[:-1], dtype=mu.dtype, device=mu.device) * 2 * math.pi + else: + rot_angle = (torch.randn(n, *mu.shape[:-1], dtype=mu.dtype, device=mu.device) * sigma) % (2 * math.pi) + return cls.from_rotvec(rot_axes * rot_angle.unsqueeze(-1)) def __mul__(self, other: Rotation) -> Self: """For compatibility with sp.spatial.transform.Rotation.""" @@ -835,13 +1393,19 @@ def __matmul__(self, other: Rotation) -> Self: rotation ``p[i]`` is composed with the corresponding rotation ``q[i]`` and `output` contains ``N`` rotations. 
""" + if not isinstance(other, Rotation): + return NotImplemented # type: ignore[unreachable] + p = self._quaternions q = other._quaternions - # TODO: broadcasting - result = _compose_quaternions(p, q) + p, q = torch.broadcast_tensors(p, q) + result_quaternions = _compose_quaternions(p, q) + result_improper = self._is_improper ^ other._is_improper + if self._single and other._single: - result = result[0] - return self.__class__(result, normalize=True, copy=False) + result_quaternions = result_quaternions[0] + result_improper = result_improper[0] + return self.__class__(result_quaternions, normalize=True, copy=False, inversion=result_improper) def __pow__(self, n: float, modulus: None = None): """Compose this rotation with itself `n` times. @@ -878,6 +1442,13 @@ def __pow__(self, n: float, modulus: None = None): then the identity rotation is returned, and if ``n == -1`` then ``p.inv()`` is returned. + For improper rotations, the power of a rotation with a reflection is + equivalent to the power of the rotation without the reflection, followed + by an reflection if the power is integer and odd. If the power is + non-integer, the reflection is never applied. + This means that, for example a 0.5 power of a rotation with a reflection + applied twice will result in a rotation without a reflection. + Note that fractional powers ``n`` which effectively take a root of rotation, do so using the shortest path smallest representation of that angle (the principal root). 
This means that powers of ``n`` and ``1/n`` @@ -896,11 +1467,15 @@ return self.inv() elif n == 1: if self._single: - return self.__class__(self._quaternions[0], copy=True) + return self.__class__(self._quaternions[0], inversion=self._is_improper[0], copy=True) else: - return self.__class__(self._quaternions, copy=True) - else: # general scaling of rotation angle - return Rotation.from_rotvec(n * self.as_rotvec()) + return self.__class__(self._quaternions, inversion=self._is_improper, copy=True) + elif math.isclose(round(n), n) and round(n) % 2: + improper: torch.Tensor | bool = self._is_improper + else: + improper = False + + return Rotation.from_rotvec(n * self.as_rotvec(), reflection=improper) def inv(self) -> Self: """Invert this rotation. @@ -914,9 +1489,55 @@ Object containing inverse of the rotations in the current instance. """ quaternions = self._quaternions * torch.tensor([-1, -1, -1, 1]) + improper = self._is_improper.clone() + + if self._single: + quaternions = quaternions[0] + improper = self._is_improper[0] + + return self.__class__(quaternions, inversion=improper, copy=False) + + def reflect(self) -> Self: + """Reflect this rotation. + + Converts a proper rotation to an improper one, or vice versa + by reflecting the rotation about a plane perpendicular to the rotation axis. + + Returns + ------- + reflected + Object containing the reflected rotations. + """ + if self._single: + quaternions = self._quaternions[0] + is_improper = self._is_improper[0] + else: + quaternions = self._quaternions + is_improper = self._is_improper + + return self.__class__(quaternions, copy=False, inversion=is_improper, reflection=True) + + def invert_axes(self) -> Self: + """Invert the axes of the coordinate system. + + Converts a proper rotation to an improper one, or vice versa + by inversion of the coordinate system. + + Note: + This is not the same as the inverse of the rotation. 
+ See `inv` for that. + + Returns + ------- + inverted_axes + Object containing the rotation with inverted axes. + """ + quaternions = self._quaternions.clone() + improper = ~self._is_improper if self._single: quaternions = quaternions[0] - return self.__class__(quaternions, copy=False) + improper = improper[0] + return self.__class__(quaternions, copy=False, inversion=improper) def magnitude(self) -> torch.Tensor: """Get the magnitude(s) of the rotation(s). @@ -960,9 +1581,9 @@ def approx_equal(self, other: Rotation, atol: float = 1e-6, degrees: bool = Fals if degrees: atol = np.deg2rad(atol) angles = (other @ self.inv()).magnitude() - return angles < atol + return (angles < atol) & (self._is_improper == other._is_improper) - def __getitem__(self, indexer: IndexerType) -> Self: + def __getitem__(self, indexer: TorchIndexerType) -> Self: """Extract rotation(s) at given index(es) from object. Create a new `Rotation` instance containing a subset of rotations @@ -984,10 +1605,10 @@ def __getitem__(self, indexer: IndexerType) -> Self: if self._single: raise TypeError('Single rotation is not subscriptable.') if isinstance(indexer, tuple): - _indexer = (*indexer, slice(None)) + indexer_quat = (*indexer, slice(None)) else: - _indexer = (indexer, slice(None)) - return self.__class__(self._quaternions[_indexer], normalize=False) + indexer_quat = (indexer, slice(None)) + return self.__class__(self._quaternions[indexer_quat], normalize=False, inversion=self._is_improper[indexer]) @property def quaternion_x(self) -> torch.Tensor: @@ -1045,7 +1666,7 @@ def quaternion_w(self, quat_w: torch.Tensor | float): axis = QUAT_AXIS_ORDER.index('w') self._quaternions[..., axis] = quat_w - def __setitem__(self, indexer: IndexerType, value: Rotation): + def __setitem__(self, indexer: TorchIndexerType, value: Rotation): """Set rotation(s) at given index(es) from object. 
Parameters @@ -1066,11 +1687,12 @@ def __setitem__(self, indexer: IndexerType, value: Rotation): raise TypeError('value must be a Rotation object') if isinstance(indexer, tuple): - _indexer = (*indexer, slice(None)) + indexer_quat = (*indexer, slice(None)) else: - _indexer = (indexer, slice(None)) - - self._quaternions[_indexer] = value.as_quat() + indexer_quat = (indexer, slice(None)) + quat, inversion = value.as_quat(improper='inversion') + self._quaternions[indexer_quat] = quat + self._is_improper[indexer] = inversion @classmethod def identity(cls, shape: int | None | tuple[int, ...] = None) -> Self: @@ -1108,6 +1730,7 @@ def align_vectors( weights: torch.Tensor | Sequence[float] | Sequence[Sequence[float]] | None = None, *, return_sensitivity: Literal[False] = False, + allow_improper: bool = ..., ) -> tuple[Rotation, float]: ... @overload @@ -1119,6 +1742,7 @@ def align_vectors( weights: torch.Tensor | Sequence[float] | Sequence[Sequence[float]] | None = None, *, return_sensitivity: Literal[True], + allow_improper: bool = ..., ) -> tuple[Rotation, float, torch.Tensor]: ... @classmethod @@ -1129,43 +1753,104 @@ def align_vectors( weights: torch.Tensor | Sequence[float] | Sequence[Sequence[float]] | None = None, *, return_sensitivity: bool = False, + allow_improper: bool = False, ) -> tuple[Rotation, float] | tuple[Rotation, float, torch.Tensor]: - """Estimate a rotation to optimally align two sets of vectors. + R"""Estimate a rotation to optimally align two sets of vectors. + + Find a rotation between frames A and B which best aligns a set of + vectors `a` and `b` observed in these frames. The following loss + function is minimized to solve for the rotation matrix :math:`R`: + + .. math:: + + L(R) = \\frac{1}{2} \\sum_{i = 1}^{n} w_i \\lVert \\mathbf{a}_i - + R \\mathbf{b}_i \\rVert^2 , + + where :math:`w_i`'s are the `weights` corresponding to each vector. 
+ + The rotation is estimated with Kabsch algorithm [1]_, and solves what + is known as the "pointing problem", or "Wahba's problem" [2]_. + + There are two special cases. The first is if a single vector is given + for `a` and `b`, in which the shortest distance rotation that aligns + `b` to `a` is returned. The second is when one of the weights is infinity. + In this case, the shortest distance rotation between the primary infinite weight + vectors is calculated as above. Then, the rotation about the aligned primary + vectors is calculated such that the secondary vectors are optimally + aligned per the above loss function. The result is the composition + of these two rotations. The result via this process is the same as the + Kabsch algorithm as the corresponding weight approaches infinity in + the limit. For a single secondary vector this is known as the + "align-constrain" algorithm [3]_. + + For both special cases (single vectors or an infinite weight), the + sensitivity matrix does not have physical meaning and an error will be + raised if it is requested. For an infinite weight, the primary vectors + act as a constraint with perfect alignment, so their contribution to + `rssd` will be forced to 0 even if they are of different lengths. + + Parameters + ---------- + a + Vector components observed in initial frame A. Each row of `a` + denotes a vector. + b + Vector components observed in another frame B. Each row of `b` + denotes a vector. + weights + Weights describing the relative importance of the vector + observations. If None (default), then all values in `weights` are + assumed to be 1. One and only one weight may be infinity, and + weights must be positive. + return_sensitivity + Whether to return the sensitivity matrix. + allow_improper + If True, allow improper rotations to be returned. If False (default), + then the rotation is restricted to be proper. + + Returns + ------- + rotation + Best estimate of the rotation that transforms `b` to `a`. 
+ rssd + Square root of the weighted sum of the squared distances between the given sets of + vectors + after alignment. + sensitivity_matrix + Sensitivity matrix of the estimated rotation estimate as explained + in Notes. - For more information, see `scipy.spatial.transform.Rotation.align_vectors`. - This will move to cpu, invoke scipy, convert to tensor, move back to device of a. + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Kabsch_algorithm + .. [2] https://en.wikipedia.org/wiki/Wahba%27s_problem + .. [3] Magner, Robert, + "Extending target tracking capabilities through trajectory and + momentum setpoint optimization." Small Satellite Conference, + 2018. """ a_tensor = torch.stack([torch.as_tensor(el) for el in a]) if isinstance(a, Sequence) else torch.as_tensor(a) - a_np = a_tensor.numpy(force=True) - b_tensor = torch.stack([torch.as_tensor(el) for el in b]) if isinstance(b, Sequence) else torch.as_tensor(b) - b_np = b_tensor.numpy(force=True) - dtype = torch.promote_types(a_tensor.dtype, b_tensor.dtype) if not dtype.is_floating_point: # boolean or integer inputs will result in float32 dtype = torch.float32 - + a_tensor = torch.atleast_2d(a_tensor).to(dtype=dtype) + b_tensor = torch.atleast_2d(b_tensor).to(dtype=dtype) if weights is None: - weights_np = None - elif isinstance(weights, torch.Tensor): - weights_np = weights.numpy(force=True) + weights_tensor = a_tensor.new_ones(a_tensor.shape[:-1], dtype=dtype) else: - weights_np = np.asarray(weights) + weights_tensor = torch.atleast_1d(torch.as_tensor(weights, dtype=dtype)) - if return_sensitivity: - rotation_sp, rssd, sensitivity_np = Rotation_scipy.align_vectors(a_np, b_np, weights_np, True) - sensitivity = torch.as_tensor(sensitivity_np, dtype=dtype) - else: - rotation_sp, rssd = Rotation_scipy.align_vectors(a_np, b_np, weights_np, False) - - quat_np = rotation_sp.as_quat() - quat = torch.as_tensor(quat_np, device=a_tensor.device, dtype=dtype) + if a_tensor.ndim > 2 or b_tensor.ndim > 2 or 
weights_tensor.ndim > 1: + raise NotImplementedError('Batched inputs are not supported.') if return_sensitivity: - return (cls(quat), float(rssd), sensitivity) + rot_matrix, rssd, sensitivity = _align_vectors(a_tensor, b_tensor, weights_tensor, True, allow_improper) + return cls.from_matrix(rot_matrix), rssd, sensitivity else: - return (cls(quat), float(rssd)) + rot_matrix, rssd = _align_vectors(a_tensor, b_tensor, weights_tensor, False, allow_improper) + return cls.from_matrix(rot_matrix), rssd @property def shape(self) -> torch.Size: @@ -1190,10 +1875,16 @@ def __len__(self) -> int: def __repr__(self): """Return String Representation of the Rotation.""" - if self._single: + if self._single and not self._is_improper: return f'Rotation({self._quaternions.tolist()})' + elif self._single and self._is_improper: + return f'improper Rotation({self._quaternions.tolist()})' + elif self._is_improper.all(): + return f'{tuple(self.shape)}-batched improper Rotation()' + elif self._is_improper.any(): + return f'{tuple(self.shape)}-batched (mixed proper/improper) Rotation()' else: - return f'{tuple(self.shape)}-Batched Rotation()' + return f'{tuple(self.shape)}-batched Rotation()' def mean( self, @@ -1214,6 +1905,11 @@ def mean( Optionally, if A is a set of Rotation matrices with multiple batch dimensions, the dimensions to reduce over can be specified. + If the rotations contains improper, the mean will be computed without + considering the improper and the result will contain a reflection if + the weighted majority of the rotations over which the mean is taken + have improper. 
+ Parameters ---------- weights @@ -1247,11 +1943,18 @@ def mean( if torch.any(weights < 0): raise ValueError('`weights` must be non-negative.') + if isinstance(dim, Sequence): + dim = tuple(dim) + + modal_improper = (weights * self._is_improper).sum(dim=dim, keepdim=keepdim) > 0.5 * weights.sum( + dim=dim, keepdim=keepdim + ) + quaternions = torch.as_tensor(self._quaternions) if dim is None: quaternions = quaternions.reshape(-1, 4) weights = weights.reshape(-1) - dim = range(len(self.shape)) + dim = list(range(len(self.shape))) else: dim = ( [d % (quaternions.ndim - 1) for d in dim] @@ -1269,4 +1972,28 @@ def mean( # unsqueeze the dimensions we removed in the reshape and product for d in sorted(dim): mean_quaternions = mean_quaternions.unsqueeze(d) - return self.__class__(mean_quaternions, normalize=False) + + return self.__class__(mean_quaternions, inversion=modal_improper, normalize=False) + + def reshape(self, *shape: int | Sequence[int]) -> Self: + """Reshape the Rotation object in the batch dimensions. + + Parameters + ---------- + shape + The new shape of the Rotation object. + + Returns + ------- + reshaped + The reshaped Rotation object. 
+ """ + newshape = [] + for s in shape: + if isinstance(s, int): + newshape.append(s) + else: + newshape.extend(s) + return self.__class__( + self._quaternions.reshape(*newshape, 4), inversion=self._is_improper.reshape(newshape), copy=True + ) diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index aaea2d5e5..12f94e8a6 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -4,15 +4,40 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Generic, Protocol, TypeVar +from typing import Generic, get_args import numpy as np import torch from numpy.typing import ArrayLike +from typing_extensions import Protocol, Self, TypeVar, overload +import mrpro.utils.typing as type_utils from mrpro.data.MoveDataMixin import MoveDataMixin -T = TypeVar('T', int, float, torch.Tensor) +# Change here to add more types +VectorTypes = torch.Tensor +ScalarTypes = int | float +T = TypeVar('T', torch.Tensor, int, float) + +# Covariant types, as SpatialDimension is a Container +# and we want, for example, SpatialDimension[int] to also be a SpatialDimension[float] +T_co = TypeVar('T_co', torch.Tensor, int, float, covariant=True) +T_co_float = TypeVar('T_co_float', float, torch.Tensor, covariant=True) +T_co_vector = torch.Tensor +T_co_scalar = TypeVar('T_co_scalar', int, float, covariant=True) + + +def _as_vectortype(x: ArrayLike) -> VectorTypes: + """Convert ArrayLike to VectorType.""" + if isinstance(x, VectorTypes) and type(x) in get_args(VectorTypes): + # exact type match + return x + if isinstance(x, VectorTypes): + # subclass of torch.Tensor + return torch.as_tensor(x) + else: + # any other ArrayLike (which is defined as convert to numpy array) + return torch.as_tensor(np.asarray(x)) class XYZ(Protocol[T]): @@ -24,32 +49,27 @@ class XYZ(Protocol[T]): @dataclass(slots=True) -class SpatialDimension(MoveDataMixin, Generic[T]): +class SpatialDimension(MoveDataMixin, Generic[T_co]): 
"""Spatial dataclass of float/int/tensors (z, y, x).""" - z: T - y: T - x: T + z: T_co + y: T_co + x: T_co @classmethod - def from_xyz(cls, data: XYZ[T], conversion: Callable[[T], T] | None = None) -> SpatialDimension[T]: + def from_xyz(cls, data: XYZ[T_co]) -> SpatialDimension[T_co]: """Create a SpatialDimension from something with (.x .y .z) parameters. Parameters ---------- data should implement .x .y .z. For example ismrmrd's matrixSizeType. - conversion, optional - will be called for each value to convert it """ - if conversion is not None: - return cls(conversion(data.z), conversion(data.y), conversion(data.x)) return cls(data.z, data.y, data.x) @staticmethod def from_array_xyz( data: ArrayLike, - conversion: Callable[[torch.Tensor], torch.Tensor] | None = None, ) -> SpatialDimension[torch.Tensor]: """Create a SpatialDimension from an arraylike interface. @@ -57,29 +77,20 @@ def from_array_xyz( ---------- data shape (..., 3) in the order (x,y,z) - conversion - will be called for each value to convert it """ - if not isinstance(data, np.ndarray | torch.Tensor): - data = np.asarray(data) - - if np.size(data, -1) != 3: - raise ValueError(f'Expected last dimension to be 3, got {np.size(data, -1)}') + data_ = _as_vectortype(data) + if np.size(data_, -1) != 3: + raise ValueError(f'Expected last dimension to be 3, got {np.size(data_, -1)}') - x = torch.as_tensor(data[..., 0]) - y = torch.as_tensor(data[..., 1]) - z = torch.as_tensor(data[..., 2]) + x = data_[..., 0] + y = data_[..., 1] + z = data_[..., 2] - if conversion is not None: - x = conversion(x) - y = conversion(y) - z = conversion(z) return SpatialDimension(z, y, x) @staticmethod def from_array_zyx( data: ArrayLike, - conversion: Callable[[torch.Tensor], torch.Tensor] | None = None, ) -> SpatialDimension[torch.Tensor]: """Create a SpatialDimension from an arraylike interface. 
@@ -87,17 +98,332 @@ def from_array_zyx( ---------- data shape (..., 3) in the order (z,y,x) - conversion - will be called for each value to convert it """ - data = torch.flip(torch.as_tensor(data), (-1,)) - return SpatialDimension.from_array_xyz(data, conversion) + data_ = _as_vectortype(data) + if np.size(data_, -1) != 3: + raise ValueError(f'Expected last dimension to be 3, got {np.size(data_, -1)}') + + x = data_[..., 2] + y = data_[..., 1] + z = data_[..., 0] + + return SpatialDimension(z, y, x) + + # This function is mainly for type hinting and docstring + def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: + """Apply a function to each z, y, x (in-place). + + Parameters + ---------- + function + function to apply + """ + return super(SpatialDimension, self).apply_(function) + + # This function is mainly for type hinting and docstring + def apply(self, function: Callable[[T], T] | None = None, **_) -> Self: + """Apply a function to each z, y, x (returning a new object). 
+ + Parameters + ---------- + function + function to apply + """ + return super(SpatialDimension, self).apply(function) @property - def zyx(self) -> tuple[T, T, T]: + def zyx(self) -> tuple[T_co, T_co, T_co]: """Return a z,y,x tuple.""" return (self.z, self.y, self.x) def __str__(self) -> str: """Return a string representation of the SpatialDimension.""" return f'z={self.z}, y={self.y}, x={self.x}' + + def __getitem__( + self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexerType + ) -> SpatialDimension[T_co_vector]: + """Get SpatialDimension item.""" + if not all(isinstance(el, VectorTypes) for el in self.zyx): + raise IndexError('Cannot index SpatialDimension with non-indexable members') + return SpatialDimension(self.z[idx], self.y[idx], self.x[idx]) + + def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexerType, other: SpatialDimension): + """Set SpatialDimension item.""" + if not all(isinstance(el, VectorTypes) for el in self.zyx): + raise IndexError('Cannot index SpatialDimension with non-indexable members') + self.z[idx] = other.z + self.y[idx] = other.y + self.x[idx] = other.x + + @overload + def __mul__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... + + @overload + def __mul__(self: SpatialDimension, other: SpatialDimension[T_co_vector]) -> SpatialDimension[T_co_vector]: ... + + @overload + def __mul__(self: SpatialDimension[int], other: float | SpatialDimension[float]) -> SpatialDimension[float]: ... + + @overload + def __mul__( + self: SpatialDimension[T_co_float], other: float | SpatialDimension[float] + ) -> SpatialDimension[T_co_float]: ... 
+ + def __mul__(self: SpatialDimension[T_co], other: float | T_co | SpatialDimension) -> SpatialDimension: + """Multiply SpatialDimension with numeric other or SpatialDimension.""" + if isinstance(other, SpatialDimension): + return SpatialDimension(self.z * other.z, self.y * other.y, self.x * other.x) + return SpatialDimension(self.z * other, self.y * other, self.x * other) + + @overload + def __rmul__(self: SpatialDimension[T_co], other: T_co) -> SpatialDimension[T_co]: ... + @overload + def __rmul__(self: SpatialDimension[int], other: float) -> SpatialDimension[float]: ... + + @overload + def __rmul__(self: SpatialDimension[T_co_float], other: float) -> SpatialDimension[T_co_float]: ... + + def __rmul__(self: SpatialDimension[T_co], other: float | T_co | SpatialDimension) -> SpatialDimension: + """Right multiply SpatialDimension with numeric other or SpatialDimension.""" + return self.__mul__(other) + + @overload + def __truediv__(self: SpatialDimension[int], other: float | SpatialDimension[float]) -> SpatialDimension[float]: ... + + @overload + def __truediv__(self: SpatialDimension, other: SpatialDimension[T_co_vector]) -> SpatialDimension[T_co_vector]: ... + + @overload + def __truediv__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... + + @overload + def __truediv__( + self: SpatialDimension[T_co_float], other: float | SpatialDimension[float] + ) -> SpatialDimension[T_co_float]: ... + + def __truediv__(self: SpatialDimension[T_co], other: float | T_co | SpatialDimension) -> SpatialDimension: + """Divide SpatialDimension by numeric other or SpatialDimension.""" + if isinstance(other, SpatialDimension): + return SpatialDimension(self.z / other.z, self.y / other.y, self.x / other.x) + return SpatialDimension(self.z / other, self.y / other, self.x / other) + + @overload + def __rtruediv__(self: SpatialDimension[int], other: float) -> SpatialDimension[float]: ... 
+ @overload + def __rtruediv__(self: SpatialDimension[T_co], other: T_co) -> SpatialDimension[T_co]: ... + + @overload + def __rtruediv__(self: SpatialDimension[T_co_float], other: float) -> SpatialDimension[T_co_float]: ... + + def __rtruediv__(self: SpatialDimension[T_co], other: float | T_co) -> SpatialDimension: + """Divide SpatialDimension or numeric other by SpatialDimension.""" + return SpatialDimension(other / self.z, other / self.y, other / self.x) + + @overload + def __add__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... + + @overload + def __add__(self: SpatialDimension, other: SpatialDimension[T_co_vector]) -> SpatialDimension[T_co_vector]: ... + + @overload + def __add__(self: SpatialDimension[int], other: float | SpatialDimension[float]) -> SpatialDimension[float]: ... + + @overload + def __add__( + self: SpatialDimension[T_co_float], other: float | SpatialDimension[float] + ) -> SpatialDimension[T_co_float]: ... + + def __add__(self: SpatialDimension[T_co], other: float | T_co | SpatialDimension) -> SpatialDimension: + """Add SpatialDimension or numeric other to SpatialDimension.""" + if isinstance(other, SpatialDimension): + return SpatialDimension(self.z + other.z, self.y + other.y, self.x + other.x) + return SpatialDimension(self.z + other, self.y + other, self.x + other) + + @overload + def __radd__(self: SpatialDimension[T_co], other: T_co) -> SpatialDimension[T_co]: ... + + @overload + def __radd__(self: SpatialDimension[int], other: float) -> SpatialDimension[float]: ... + + @overload + def __radd__(self: SpatialDimension[T_co_float], other: float) -> SpatialDimension[T_co_float]: ... + + def __radd__(self: SpatialDimension[T_co], other: float | T_co) -> SpatialDimension: + """Right add numeric other to SpatialDimension.""" + return self.__add__(other) + + @overload + def __floordiv__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... 
+ + @overload + def __floordiv__( + self: SpatialDimension[int], other: float | SpatialDimension[float] + ) -> SpatialDimension[float]: ... + + @overload + def __floordiv__(self: SpatialDimension, other: SpatialDimension[T_co_vector]) -> SpatialDimension[T_co_vector]: ... + + @overload + def __floordiv__( + self: SpatialDimension[T_co_float], other: float | SpatialDimension[float] + ) -> SpatialDimension[T_co_float]: ... + + def __floordiv__(self: SpatialDimension[T_co], other: float | T_co | SpatialDimension) -> SpatialDimension: + """Floor divide SpatialDimension by numeric other.""" + if isinstance(other, SpatialDimension): + return SpatialDimension(self.z // other.z, self.y // other.y, self.x // other.x) + return SpatialDimension(self.z // other, self.y // other, self.x // other) + + @overload + def __rfloordiv__(self: SpatialDimension[T_co], other: T_co) -> SpatialDimension[T_co]: ... + + @overload + def __rfloordiv__(self: SpatialDimension[int], other: float) -> SpatialDimension[float]: ... + + @overload + def __rfloordiv__(self: SpatialDimension[T_co_float], other: float) -> SpatialDimension[T_co_float]: ... + + def __rfloordiv__(self: SpatialDimension[T_co], other: float | T_co) -> SpatialDimension: + """Floor divide other by SpatialDimension.""" + if isinstance(other, SpatialDimension): + return SpatialDimension(other.z // self.z, other.y // self.y, other.x // self.x) + return SpatialDimension(other // self.z, other // self.y, other // self.x) + + @overload + def __sub__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... + + @overload + def __sub__(self: SpatialDimension, other: SpatialDimension[T_co_vector]) -> SpatialDimension[T_co_vector]: ... + + @overload + def __sub__(self: SpatialDimension[int], other: float | SpatialDimension[float]) -> SpatialDimension[float]: ... 
+ + @overload + def __sub__( + self: SpatialDimension[T_co_float], other: float | SpatialDimension[float] + ) -> SpatialDimension[T_co_float]: ... + + def __sub__(self: SpatialDimension[T_co], other: float | T_co | SpatialDimension) -> SpatialDimension: + """Subtract SpatialDimension or numeric other to SpatialDimension.""" + if isinstance(other, SpatialDimension): + return SpatialDimension(self.z - other.z, self.y - other.y, self.x - other.x) + return SpatialDimension(self.z - other, self.y - other, self.x - other) + + @overload + def __rsub__(self: SpatialDimension[T_co], other: T_co) -> SpatialDimension[T_co]: ... + + @overload + def __rsub__(self: SpatialDimension[int], other: float) -> SpatialDimension[float]: ... + + @overload + def __rsub__(self: SpatialDimension[T_co_float], other: float) -> SpatialDimension[T_co_float]: ... + + def __rsub__(self: SpatialDimension[T_co], other: float | T_co) -> SpatialDimension: + """Subtract SpatialDimension from numeric other or SpatialDimension.""" + if isinstance(other, SpatialDimension): + return SpatialDimension(other.z - self.z, other.y - self.y, other.x - self.x) + return SpatialDimension(other - self.z, other - self.y, other - self.x) + + def __neg__(self: SpatialDimension[T_co]) -> SpatialDimension[T_co]: + """Negate SpatialDimension.""" + return SpatialDimension(-self.z, -self.y, -self.x) + + @overload + def __eq__(self: SpatialDimension[T_co_scalar], other: object) -> bool: ... + @overload + def __eq__(self: SpatialDimension[T_co_vector], other: SpatialDimension[T_co_vector]) -> T_co_vector: ... 
+ + def __eq__( + self: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + other: object | SpatialDimension[T_co_vector], + ) -> bool | T_co_vector: + """Check if self is equal to other.""" + if not isinstance(other, SpatialDimension): + return NotImplemented + return (self.z == other.z) & (self.y == other.y) & (self.x == other.x) + + @overload + def __lt__(self: SpatialDimension[T_co_vector], other: SpatialDimension[T_co_vector]) -> T_co_vector: ... + @overload + def __lt__(self: SpatialDimension[T_co_scalar], other: SpatialDimension[T_co_scalar]) -> bool: ... + def __lt__( + self: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + other: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + ) -> bool | T_co_vector: + """Check if self is less than other.""" + if not isinstance(other, SpatialDimension): + return NotImplemented + return (self.x < other.x) & (self.y < other.y) & (self.z < other.z) + + @overload + def __le__(self: SpatialDimension[T_co_vector], other: SpatialDimension[T_co_vector]) -> T_co_vector: ... + @overload + def __le__(self: SpatialDimension[T_co_scalar], other: SpatialDimension[T_co_scalar]) -> bool: ... + def __le__( + self: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + other: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + ) -> bool | T_co_vector: + """Check if self is less than or equal to other.""" + if not isinstance(other, SpatialDimension): + return NotImplemented + return (self.x <= other.x) & (self.y <= other.y) & (self.z <= other.z) + + @overload + def __gt__(self: SpatialDimension[T_co_vector], other: SpatialDimension[T_co_vector]) -> T_co_vector: ... + @overload + def __gt__(self: SpatialDimension[T_co_scalar], other: SpatialDimension[T_co_scalar]) -> bool: ... 
+ def __gt__( + self: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + other: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + ) -> bool | T_co_vector: + """Check if self is greater than other.""" + if not isinstance(other, SpatialDimension): + return NotImplemented + return (self.x > other.x) & (self.y > other.y) & (self.z > other.z) + + @overload + def __ge__(self: SpatialDimension[T_co_vector], other: SpatialDimension[T_co_vector]) -> T_co_vector: ... + @overload + def __ge__(self: SpatialDimension[T_co_scalar], other: SpatialDimension[T_co_scalar]) -> bool: ... + def __ge__( + self: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + other: SpatialDimension[T_co_scalar] | SpatialDimension[T_co_vector], + ) -> bool | T_co_vector: + """Check if self is greater than or equal to other.""" + if not isinstance(other, SpatialDimension): + return NotImplemented + return (self.x >= other.x) & (self.y >= other.y) & (self.z >= other.z) + + def __post_init__(self): + """Ensure that the data is of matching shape.""" + if not all(isinstance(val, (int | float)) for val in self.zyx): + try: + zyx = [_as_vectortype(v) for v in self.zyx] + self.z, self.y, self.x = torch.broadcast_tensors(*zyx) + except RuntimeError: + raise ValueError('The shapes of the tensors do not match') from None + + @property + def shape(self) -> tuple[int, ...]: + """Get the shape of the x, y, and z. 
+ + Returns + ------- + Empty tuple if x, y, and z are scalar types, otherwise shape + + Raises + ------ + ValueError if the shapes are not equal + """ + if isinstance(self.x, ScalarTypes) and isinstance(self.y, ScalarTypes) and isinstance(self.z, ScalarTypes): + return () + elif ( + isinstance(self.x, VectorTypes) + and isinstance(self.y, VectorTypes) + and isinstance(self.z, VectorTypes) + and self.x.shape == self.y.shape == self.z.shape + ): + return self.x.shape + else: + raise ValueError('Inconsistent shapes') diff --git a/src/mrpro/data/TrajectoryDescription.py b/src/mrpro/data/TrajectoryDescription.py deleted file mode 100644 index f1a9cce5a..000000000 --- a/src/mrpro/data/TrajectoryDescription.py +++ /dev/null @@ -1,29 +0,0 @@ -"""TrajectoryDescription dataclass.""" - -import dataclasses -from dataclasses import dataclass -from typing import Self - -from ismrmrd.xsd.ismrmrdschema.ismrmrd import trajectoryDescriptionType - - -@dataclass(slots=True) -class TrajectoryDescription: - """TrajectoryDescription dataclass.""" - - identifier: str = '' - user_parameter_long: dict[str, int] = dataclasses.field(default_factory=dict) - user_parameter_double: dict[str, float] = dataclasses.field(default_factory=dict) - user_parameter_string: dict[str, str] = dataclasses.field(default_factory=dict) - comment: str = '' - - @classmethod - def from_ismrmrd(cls, trajectory_description: trajectoryDescriptionType) -> Self: - """Create TrajectoryDescription from ismrmrd traj description.""" - return cls( - user_parameter_long={p.name: int(p.value) for p in trajectory_description.userParameterLong}, - user_parameter_double={p.name: float(p.value) for p in trajectory_description.userParameterDouble}, - user_parameter_string={p.name: str(p.value) for p in trajectory_description.userParameterString}, - comment=trajectory_description.comment or '', - identifier=trajectory_description.identifier or '', - ) diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index 
a954ca1b3..d5667a5bc 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -16,5 +16,27 @@ from mrpro.data.QHeader import QHeader from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription -__all__ = ["enums", "traj_calculators", "acq_filters", "AcqIdx", "AcqInfo", "CsmData", "Data", "DcfData", "EncodingLimits", "Limits", "IData", "IHeader", "KData", "KHeader", "KNoise", "KTrajectory", "KTrajectoryRawShape", "MoveDataMixin", "QData", "QHeader", "Rotation", "SpatialDimension", "TrajectoryDescription"] +__all__ = [ + "AcqIdx", + "AcqInfo", + "CsmData", + "Data", + "DcfData", + "EncodingLimits", + "IData", + "IHeader", + "KData", + "KHeader", + "KNoise", + "KTrajectory", + "KTrajectoryRawShape", + "Limits", + "MoveDataMixin", + "QData", + "QHeader", + "Rotation", + "SpatialDimension", + "acq_filters", + "enums", + "traj_calculators" +] diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index 7da9b4d1f..4b5df6250 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -3,30 +3,31 @@ import dataclasses import datetime import warnings -from collections.abc import Callable +from collections.abc import Callable, Sequence from pathlib import Path -from typing import Self +from types import EllipsisType import h5py import ismrmrd import numpy as np import torch from einops import rearrange +from typing_extensions import Self from mrpro.data._kdata.KDataRearrangeMixin import KDataRearrangeMixin from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin -from mrpro.data.acq_filters import is_image_acquisition -from mrpro.data.AcqInfo import AcqInfo +from mrpro.data.acq_filters import has_n_coils, is_image_acquisition +from mrpro.data.AcqInfo import AcqInfo, 
rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd -from mrpro.utils import modify_acq_info KDIM_SORT_LABELS = ( 'k1', @@ -110,6 +111,29 @@ def from_file( modification_time = datetime.datetime.fromtimestamp(mtime) acquisitions = [acq for acq in acquisitions if acquisition_filter_criterion(acq)] + + # we need the same number of receiver coils for all acquisitions + n_coils_available = {acq.data.shape[0] for acq in acquisitions} + if len(n_coils_available) > 1: + if ( + ismrmrd_header.acquisitionSystemInformation is not None + and ismrmrd_header.acquisitionSystemInformation.receiverChannels is not None + ): + n_coils = int(ismrmrd_header.acquisitionSystemInformation.receiverChannels) + else: + # most likely, highest number of elements are the coils used for imaging + n_coils = int(max(n_coils_available)) + + warnings.warn( + f'Acquisitions with different number {n_coils_available} of receiver coil elements detected. ' + f'Data with {n_coils} receiver coil elements will be used.', + stacklevel=1, + ) + acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)] + + if not acquisitions: + raise ValueError('No acquisitions meeting the given filter criteria were found.') + kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions]) acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions) @@ -200,10 +224,13 @@ def from_file( sort_idx = np.lexsort(acq_indices) # torch does not have lexsort as of pytorch 2.2 (March 2024) # Finally, reshape and sort the tensors in acqinfo and acqinfo.idx, and kdata. 
- def sort_and_reshape_tensor_fields(input_tensor: torch.Tensor): - return rearrange(input_tensor[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2) - - kheader.acq_info = modify_acq_info(sort_and_reshape_tensor_fields, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2 + ) + if isinstance(field, torch.Tensor | Rotation) + else field + ) kdata = rearrange(kdata[sort_idx], '(other k2 k1) coils k0 -> other coils k2 k1 k0', k1=n_k1, k2=n_k2) # Calculate trajectory and check if it matches the kdata shape @@ -250,3 +277,92 @@ def __repr__(self): f'{self.header}' ) return out + + def compress_coils( + self: Self, + n_compressed_coils: int, + batch_dims: None | Sequence[int] = None, + joint_dims: Sequence[int] | EllipsisType = ..., + ) -> Self: + """Reduce the number of coils based on a PCA compression. + + A PCA is carried out along the coil dimension and the n_compressed_coils virtual coil elements are selected. For + more information on coil compression please see [BUE2007]_, [DON2008]_ and [HUA2008]_. + + Returns a copy of the data. + + Parameters + ---------- + kdata + K-space data + n_compressed_coils + Number of compressed coils + batch_dims + Dimensions which are treated as batched, i.e. separate coil compression matrices (e.g. different slices). + Default is to do one coil compression matrix for the entire k-space data. Only batch_dim or joint_dim can + be defined. If batch_dims is not None then joint_dims has to be ... + joint_dims + Dimensions which are combined to calculate single coil compression matrix (e.g. k0, k1, contrast). Default + is that all dimensions (except for the coil dimension) are joint_dims. Only batch_dim or joint_dim can + be defined. If joint_dims is not ... batch_dims has to be None + + Returns + ------- + Copy of K-space data with compressed coils. 
+ + Raises + ------ + ValueError + If both batch_dims and joint_dims are defined. + ValueError + If coil dimension is part of joint_dims or batch_dims. + + References + ---------- + .. [BUE2007] Buehrer M, Pruessmann KP, Boesiger P, Kozerke S (2007) Array compression for MRI with large coil + arrays. MRM 57. https://doi.org/10.1002/mrm.21237 + .. [DON2008] Doneva M, Boernert P (2008) Automatic coil selection for channel reduction in SENSE-based parallel + imaging. MAGMA 21. https://doi.org/10.1007/s10334-008-0110-x + .. [HUA2008] Huang F, Vijayakumar S, Li Y, Hertel S, Duensing GR (2008) A software channel compression + technique for faster reconstruction with many channels. MRM 26. https://doi.org/10.1016/j.mri.2007.04.010 + + """ + from mrpro.operators import PCACompressionOp + + coil_dim = -4 % self.data.ndim + if batch_dims is not None and joint_dims is not Ellipsis: + raise ValueError('Either batch_dims or joint_dims can be defined not both.') + + if joint_dims is not Ellipsis: + joint_dims_normalized = [i % self.data.ndim for i in joint_dims] + if coil_dim in joint_dims_normalized: + raise ValueError('Coil dimension must not be in joint_dims') + batch_dims_normalized = [ + d for d in range(self.data.ndim) if d not in joint_dims_normalized and d is not coil_dim + ] + else: + batch_dims_normalized = [] if batch_dims is None else [i % self.data.ndim for i in batch_dims] + if coil_dim in batch_dims_normalized: + raise ValueError('Coil dimension must not be in batch_dims') + + # reshape to (*batch dimension, -1, coils) + permute_order = ( + batch_dims_normalized + + [i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized] + + [coil_dim] + ) + kdata_coil_compressed = self.data.permute(permute_order) + permuted_kdata_shape = kdata_coil_compressed.shape + kdata_coil_compressed = kdata_coil_compressed.flatten( + start_dim=len(batch_dims_normalized), end_dim=-2 + ) # keep separate dimensions and coil + + pca_compression_op = 
PCACompressionOp(data=kdata_coil_compressed, n_components=n_compressed_coils) + (kdata_coil_compressed,) = pca_compression_op(kdata_coil_compressed) + + # reshape to original dimensions and undo permutation + kdata_coil_compressed = torch.reshape( + kdata_coil_compressed, [*permuted_kdata_shape[:-1], n_compressed_coils] + ).permute(*np.argsort(permute_order)) + + return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone()) diff --git a/src/mrpro/data/_kdata/KDataProtocol.py b/src/mrpro/data/_kdata/KDataProtocol.py index bbb969b49..485a8fc4d 100644 --- a/src/mrpro/data/_kdata/KDataProtocol.py +++ b/src/mrpro/data/_kdata/KDataProtocol.py @@ -1,8 +1,9 @@ """Protocol for KData.""" -from typing import Literal, Protocol, Self +from typing import Literal import torch +from typing_extensions import Protocol, Self from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory diff --git a/src/mrpro/data/_kdata/KDataRearrangeMixin.py b/src/mrpro/data/_kdata/KDataRearrangeMixin.py index b591b9838..23a58dea6 100644 --- a/src/mrpro/data/_kdata/KDataRearrangeMixin.py +++ b/src/mrpro/data/_kdata/KDataRearrangeMixin.py @@ -1,13 +1,12 @@ """Rearrange KData.""" import copy -from typing import Self from einops import rearrange +from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import AcqInfo -from mrpro.utils import modify_acq_info +from mrpro.data.AcqInfo import rearrange_acq_info_fields class KDataRearrangeMixin(_KDataProtocol): @@ -35,9 +34,8 @@ def rearrange_k2_k1_into_k1(self: Self) -> Self: kheader = copy.deepcopy(self.header) # Update shape of acquisition info index - def reshape_acq_info(info: AcqInfo): - return rearrange(info, 'other k2 k1 ... -> other 1 (k2 k1) ...') - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... 
-> other 1 (k2 k1) ...') + ) return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py b/src/mrpro/data/_kdata/KDataRemoveOsMixin.py index 3de4c0d32..555f56a39 100644 --- a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py +++ b/src/mrpro/data/_kdata/KDataRemoveOsMixin.py @@ -1,9 +1,9 @@ """Remove oversampling along readout dimension.""" from copy import deepcopy -from typing import Self import torch +from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol from mrpro.data.KTrajectory import KTrajectory @@ -49,7 +49,7 @@ def remove_readout_os(self: Self) -> Self: start_cropped_readout = (self.header.encoding_matrix.x - self.header.recon_matrix.x) // 2 end_cropped_readout = start_cropped_readout + self.header.recon_matrix.x - def crop_readout(data_to_crop: torch.Tensor): + def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor: # returns a cropped copy return data_to_crop[..., start_cropped_readout:end_cropped_readout].clone() @@ -61,7 +61,7 @@ def crop_readout(data_to_crop: torch.Tensor): ks = [self.traj.kz, self.traj.ky, self.traj.kx] # only cropped ks that are not broadcasted/singleton along k0 cropped_ks = [crop_readout(k) if k.shape[-1] > 1 else k.clone() for k in ks] - cropped_traj = KTrajectory(*cropped_ks) + cropped_traj = KTrajectory(cropped_ks[0], cropped_ks[1], cropped_ks[2]) # Adapt header parameters header = deepcopy(self.header) diff --git a/src/mrpro/data/_kdata/KDataSelectMixin.py b/src/mrpro/data/_kdata/KDataSelectMixin.py index 43a2e2b2b..8f8a452cf 100644 --- a/src/mrpro/data/_kdata/KDataSelectMixin.py +++ b/src/mrpro/data/_kdata/KDataSelectMixin.py @@ -1,12 +1,13 @@ """Select subset along other dimensions of KData.""" import copy -from typing import Literal, Self +from typing import Literal import torch +from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.utils import modify_acq_info +from 
mrpro.data.Rotation import Rotation class KDataSelectMixin(_KDataProtocol): @@ -50,10 +51,9 @@ def select_other_subset( other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) # Adapt header - def select_acq_info(info: torch.Tensor): - return info[other_idx, ...] - - kheader.acq_info = modify_acq_info(select_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field + ) # Select data kdat = self.data[other_idx, ...] diff --git a/src/mrpro/data/_kdata/KDataSplitMixin.py b/src/mrpro/data/_kdata/KDataSplitMixin.py index a58b7181b..c28004af4 100644 --- a/src/mrpro/data/_kdata/KDataSplitMixin.py +++ b/src/mrpro/data/_kdata/KDataSplitMixin.py @@ -1,14 +1,17 @@ """Mixin class to split KData into other subsets.""" -import copy -from typing import Literal, Self +from typing import Literal, TypeVar, cast import torch from einops import rearrange, repeat +from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation + +RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) class KDataSplitMixin(_KDataProtocol): @@ -55,8 +58,9 @@ def _split_k2_or_k1_into_other( def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, :, split_idx, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, :, split_idx, ...] 
+ def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + # cast due to https://github.com/python/mypy/issues/10817 + return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' @@ -68,8 +72,8 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, split_idx, :, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + return cast(RotationOrTensor, acq_info[:, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' @@ -92,13 +96,14 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) # Create new header with correct shape - kheader = copy.deepcopy(self.header) + kheader = self.header.clone() # Update shape of acquisition info index - def reshape_acq_info(info: torch.Tensor): - return rearrange(split_acq_info(info), rearrange_pattern_acq_info) - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) + if isinstance(field, Rotation | torch.Tensor) + else field + ) # Update other label limits and acquisition info setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) diff --git a/src/mrpro/data/acq_filters.py b/src/mrpro/data/acq_filters.py index d64c4d9a6..4723d3bba 100644 --- a/src/mrpro/data/acq_filters.py +++ b/src/mrpro/data/acq_filters.py @@ -61,3 +61,20 @@ def is_coil_calibration_acquisition(acquisition: ismrmrd.Acquisition) -> bool: """ coil_calibration_flag = 
AcqFlags.ACQ_IS_PARALLEL_CALIBRATION | AcqFlags.ACQ_IS_PARALLEL_CALIBRATION_AND_IMAGING return coil_calibration_flag.value & acquisition.flags + + +def has_n_coils(n_coils: int, acquisition: ismrmrd.Acquisition) -> bool: + """Test if acquisitions was obtained with a certain number of receiver coils. + + Parameters + ---------- + n_coils + number of receiver coils + acquisition + ISMRMRD acquisition + + Returns + ------- + True if the acquisition was obtained with n_coils receiver coils + """ + return acquisition.data.shape[0] == n_coils diff --git a/src/mrpro/data/traj_calculators/__init__.py b/src/mrpro/data/traj_calculators/__init__.py index 72d61536b..2fd0e5b5a 100644 --- a/src/mrpro/data/traj_calculators/__init__.py +++ b/src/mrpro/data/traj_calculators/__init__.py @@ -5,4 +5,12 @@ from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd from mrpro.data.traj_calculators.KTrajectoryPulseq import KTrajectoryPulseq from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian -__all__ = ["KTrajectoryCalculator", "KTrajectoryRpe", "KTrajectorySunflowerGoldenRpe", "KTrajectoryRadial2D", "KTrajectoryIsmrmrd", "KTrajectoryPulseq", "KTrajectoryCartesian"] +__all__ = [ + "KTrajectoryCalculator", + "KTrajectoryCartesian", + "KTrajectoryIsmrmrd", + "KTrajectoryPulseq", + "KTrajectoryRadial2D", + "KTrajectoryRpe", + "KTrajectorySunflowerGoldenRpe" +] \ No newline at end of file diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 64068a5d1..47c71c77f 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -1,5 +1,7 @@ """Cartesian Sampling Operator.""" +import warnings + import torch from einops import rearrange, repeat @@ -7,6 +9,7 @@ from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension from mrpro.operators.LinearOperator import LinearOperator +from mrpro.utils.reshape import 
unsqueeze_left class CartesianSamplingOp(LinearOperator): @@ -47,26 +50,53 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2 else: sorted_grid_shape.x = ktraj_tensor.shape[-1] - kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1) + kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k2 k1 k0', other=1, k2=1, k1=1) if traj_type_kzyx[-2] == TrajType.ONGRID: # ky ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2 else: sorted_grid_shape.y = ktraj_tensor.shape[-2] - ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1) + ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k2 k1 k0', other=1, k2=1, k0=1) if traj_type_kzyx[-3] == TrajType.ONGRID: # kz kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2 else: sorted_grid_shape.z = ktraj_tensor.shape[-3] - kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1) + kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k2 k1 k0', other=1, k1=1, k0=1) # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') + + # check that all points are inside the encoding matrix + inside_encoding_matrix = ( + ((kx_idx >= 0) & (kx_idx < sorted_grid_shape.x)) + & ((ky_idx >= 0) & (ky_idx < sorted_grid_shape.y)) + & ((kz_idx >= 0) & (kz_idx < sorted_grid_shape.z)) + ) + if not torch.all(inside_encoding_matrix): + warnings.warn( + 'K-space points lie outside of the encoding_matrix and will be ignored.' + ' Increase the encoding_matrix to include these points.', + stacklevel=2, + ) + + inside_encoding_matrix = rearrange(inside_encoding_matrix, '... 
kz ky kx -> ... 1 (kz ky kx)') + inside_encoding_matrix_idx = inside_encoding_matrix.nonzero(as_tuple=True)[-1] + inside_encoding_matrix_idx = torch.reshape(inside_encoding_matrix_idx, (*kidx.shape[:-1], -1)) + self.register_buffer('_inside_encoding_matrix_idx', inside_encoding_matrix_idx) + kidx = torch.take_along_dim(kidx, inside_encoding_matrix_idx, dim=-1) + else: + self._inside_encoding_matrix_idx: torch.Tensor | None = None + self.register_buffer('_fft_idx', kidx) + # we can skip the indexing if the data is already sorted - self._needs_indexing = not torch.all(torch.diff(kidx) == 1) + self._needs_indexing = ( + not torch.all(torch.diff(kidx) == 1) + or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + or self._inside_encoding_matrix_idx is not None + ) self._trajectory_shape = traj.broadcasted_shape self._sorted_grid_shape = sorted_grid_shape @@ -91,8 +121,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: return (x,) x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)') - # take_along_dim does broadcast, so no need for extending here - x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1) + # take_along_dim broadcasts, but needs the same number of dimensions + idx = unsqueeze_left(self._fft_idx, x_kflat.ndim - self._fft_idx.ndim) + x_inside_encoding_matrix = torch.take_along_dim(x_kflat, idx, dim=-1) + + if self._inside_encoding_matrix_idx is None: + # all trajectory points are inside the encoding matrix + x_indexed = x_inside_encoding_matrix + else: + # we need to add zeros + x_indexed = self._broadcast_and_scatter_along_last_dim( + x_inside_encoding_matrix, + self._trajectory_shape[-1] * self._trajectory_shape[-2] * self._trajectory_shape[-3], + self._inside_encoding_matrix_idx, + ) + # reshape to (... 
other coil, k2, k1, k0) x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:]) @@ -118,18 +161,13 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)') - # scatter does not broadcast, so we need to manually broadcast the indices - broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1]) - idx_expanded = torch.broadcast_to(self._fft_idx, (*broadcast_shape, self._fft_idx.shape[-1])) + if self._inside_encoding_matrix_idx is not None: + idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim) + y_kflat = torch.take_along_dim(y_kflat, idx, dim=-1) - # although scatter_ is inplace, this will not cause issues with autograd, as self - # is always constant zero and gradients w.r.t. src work as expected. - y_scattered = torch.zeros( - *broadcast_shape, - self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, - dtype=y.dtype, - device=y.device, - ).scatter_(dim=-1, index=idx_expanded, src=y_kflat) + y_scattered = self._broadcast_and_scatter_along_last_dim( + y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx + ) # reshape to ..., other, coil, k2_enc, k1_enc, k0_enc y_reshaped = y_scattered.reshape( @@ -140,3 +178,103 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: ) return (y_reshaped,) + + @staticmethod + def _broadcast_and_scatter_along_last_dim( + data_to_scatter: torch.Tensor, n_last_dim: int, scatter_index: torch.Tensor + ) -> torch.Tensor: + """Broadcast scatter index and scatter into zero tensor. 
+ + Parameters + ---------- + data_to_scatter + Data to be scattered at indices scatter_index + n_last_dim + Number of data points in last dimension + scatter_index + Indices describing where to scatter data + + Returns + ------- + Data scattered into tensor along scatter_index + """ + # scatter does not broadcast, so we need to manually broadcast the indices + broadcast_shape = torch.broadcast_shapes(scatter_index.shape[:-1], data_to_scatter.shape[:-1]) + idx_expanded = torch.broadcast_to(scatter_index, (*broadcast_shape, scatter_index.shape[-1])) + + # although scatter_ is inplace, this will not cause issues with autograd, as self + # is always constant zero and gradients w.r.t. src work as expected. + data_scattered = torch.zeros( + *broadcast_shape, + n_last_dim, + dtype=data_to_scatter.dtype, + device=data_to_scatter.device, + ).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter) + + return data_scattered + + @property + def gram(self) -> 'CartesianSamplingGramOp': + """Return the Gram operator for this Cartesian Sampling Operator. + + Returns + ------- + Gram operator for this Cartesian Sampling Operator + """ + return CartesianSamplingGramOp(self) + + +class CartesianSamplingGramOp(LinearOperator): + """Gram operator for Cartesian Sampling Operator. + + The Gram operator is the composition CartesianSamplingOp.H @ CartesianSamplingOp. + """ + + def __init__(self, sampling_op: CartesianSamplingOp): + """Initialize Cartesian Sampling Gram Operator class. + + This should not be used directly, but rather through the `gram` method of a + :class:`mrpro.operator.CartesianSamplingOp` object. + + Parameters + ---------- + sampling_op + The Cartesian Sampling Operator for which to create the Gram operator. 
+ """ + super().__init__() + if sampling_op._needs_indexing: + ones = torch.ones(*sampling_op._trajectory_shape[:-3], 1, *sampling_op._sorted_grid_shape.zyx) + (mask,) = sampling_op.adjoint(*sampling_op.forward(ones)) + self.register_buffer('_mask', mask) + else: + self._mask: torch.Tensor | None = None + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the Gram operator. + + Parameters + ---------- + x + Input data + + Returns + ------- + Output data + """ + if self._mask is None: + return (x,) + return (x * self._mask,) + + def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint of the Gram operator. + + Parameters + ---------- + y + Input data + + Returns + ------- + Output data + """ + return self.forward(y) diff --git a/src/mrpro/operators/EndomorphOperator.py b/src/mrpro/operators/EndomorphOperator.py index 6608d3018..2d6a59439 100644 --- a/src/mrpro/operators/EndomorphOperator.py +++ b/src/mrpro/operators/EndomorphOperator.py @@ -4,9 +4,10 @@ from abc import abstractmethod from collections.abc import Callable -from typing import Any, ParamSpec, Protocol, TypeAlias, TypeVar, TypeVarTuple, cast, overload +from typing import TypeAlias, cast import torch +from typing_extensions import Any, ParamSpec, Protocol, TypeVar, TypeVarTuple, Unpack, overload import mrpro.operators from mrpro.operators.Operator import Operator @@ -172,7 +173,7 @@ def __call__( torch.Tensor, torch.Tensor, torch.Tensor, - *tuple[torch.Tensor, ...], + Unpack[tuple[torch.Tensor, ...]], ]: ... @overload @@ -191,7 +192,7 @@ def endomorph(f: F, /) -> _EndomorphCallable: return f -class EndomorphOperator(Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]): +class EndomorphOperator(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]]): """Endomorph Operator. Endomorph Operators have N tensor inputs and exactly N outputs. 
@@ -211,9 +212,11 @@ def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: @overload def __matmul__(self, other: EndomorphOperator) -> EndomorphOperator: ... @overload - def __matmul__(self, other: Operator[*Tin, Tout]) -> Operator[*Tin, Tout]: ... + def __matmul__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: ... - def __matmul__(self, other: Operator[*Tin, Tout] | EndomorphOperator) -> Operator[*Tin, Tout] | EndomorphOperator: + def __matmul__( + self, other: Operator[Unpack[Tin], Tout] | EndomorphOperator + ) -> Operator[Unpack[Tin], Tout] | EndomorphOperator: """Operator composition.""" if isinstance(other, mrpro.operators.MultiIdentityOp): return self @@ -224,8 +227,8 @@ def __matmul__(self, other: Operator[*Tin, Tout] | EndomorphOperator) -> Operato if isinstance(other, EndomorphOperator): return cast(EndomorphOperator, res) else: - return cast(Operator[*Tin, Tout], res) + return cast(Operator[Unpack[Tin], Tout], res) - def __rmatmul__(self, other: Operator[*Tin, Tout]) -> Operator[*Tin, Tout]: + def __rmatmul__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: """Operator composition.""" - return other.__matmul__(cast(Operator[*Tin, tuple[*Tin]], self)) + return other.__matmul__(cast(Operator[Unpack[Tin], tuple[Unpack[Tin]]], self)) diff --git a/src/mrpro/operators/FastFourierOp.py b/src/mrpro/operators/FastFourierOp.py index 53f3f6eb4..4ffe7e3a6 100644 --- a/src/mrpro/operators/FastFourierOp.py +++ b/src/mrpro/operators/FastFourierOp.py @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: FFT of x """ y = torch.fft.fftshift( - torch.fft.fftn(torch.fft.ifftshift(*self._pad_op.forward(x), dim=self._dim), dim=self._dim, norm='ortho'), + torch.fft.fftn(torch.fft.ifftshift(*self._pad_op(x), dim=self._dim), dim=self._dim, norm='ortho'), dim=self._dim, ) return (y,) diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index c7c16435a..a3e81aba7 
100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -1,21 +1,23 @@ """Fourier Operator.""" from collections.abc import Sequence -from typing import Self +from itertools import product import numpy as np import torch from torchkbnufft import KbNufft, KbNufftAdjoint +from typing_extensions import Self from mrpro.data._kdata.KData import KData from mrpro.data.enums import TrajType from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp from mrpro.operators.FastFourierOp import FastFourierOp from mrpro.operators.LinearOperator import LinearOperator -class FourierOp(LinearOperator): +class FourierOp(LinearOperator, adjoint_as_backward=True): """Fourier Operator class.""" def __init__( @@ -67,12 +69,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): self._nufft_dims.append(dim) if self._fft_dims: - self._fast_fourier_op = FastFourierOp( + self._fast_fourier_op: FastFourierOp | None = FastFourierOp( dim=tuple(self._fft_dims), recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims), encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims), ) - + self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp( + encoding_matrix=encoding_matrix, traj=traj + ) + else: + self._fast_fourier_op = None + self._cart_sampling_op = None # Find dimensions which require NUFFT if self._nufft_dims: fft_dims_k210 = [ @@ -102,20 +109,23 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction - self._fwd_nufft_op = KbNufft( + self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - self._adj_nufft_op = KbNufftAdjoint( + 
self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - - self._kshape = traj.broadcasted_shape + else: + self._omega: torch.Tensor | None = None + self._fwd_nufft_op = None + self._adj_nufft_op = None + self._kshape = traj.broadcasted_shape @classmethod def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self: @@ -146,11 +156,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil k-space data with shape: (... coils k2 k1 k0) """ - if len(self._fft_dims): - # FFT - (x,) = self._fast_fourier_op.forward(x) - - if self._nufft_dims: + if self._fwd_nufft_op is not None and self._omega is not None: + # NUFFT Type 2 # we need to move the nufft-dimensions to the end and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils nufft_dims # we could move the permute to __init__ but then we still would need to prepend if len(other)>1 @@ -163,7 +170,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.flatten(end_dim=-len(keep_dims) - 1) # omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims) - # TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape). 
omega = self._omega.permute(*permute) omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :]) omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) @@ -173,6 +179,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims] x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils x = x.permute(*unpermute) + + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: + # FFT + (x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0]) + return (x,) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: @@ -187,11 +198,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil image data with shape: (... coils z y x) """ - if self._fft_dims: + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: # IFFT - (x,) = self._fast_fourier_op.adjoint(x) + (x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0]) - if self._nufft_dims: + if self._adj_nufft_op is not None and self._omega is not None: + # NUFFT Type 1 # we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils (nufft_dims) keep_dims = [-4, *self._nufft_dims] # -4 is coil @@ -212,3 +224,163 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.permute(*unpermute) return (x,) + + @property + def gram(self) -> LinearOperator: + """Return the gram operator.""" + return FourierGramOp(self) + + +def symmetrize(kernel: torch.Tensor, rank: int) -> torch.Tensor: + """Enforce hermitian symmetry on the kernel. 
Returns only half of the kernel.""" + flipped = kernel.clone() + for d in range(-rank, 0): + flipped = flipped.index_select(d, -1 * torch.arange(flipped.shape[d], device=flipped.device) % flipped.size(d)) + kernel = (kernel + flipped.conj()) / 2 + last_len = kernel.shape[-1] + return kernel[..., : last_len // 2 + 1] + + +def gram_nufft_kernel(weight: torch.Tensor, trajectory: torch.Tensor, recon_shape: Sequence[int]) -> torch.Tensor: + """Calculate the convolution kernel for the NUFFT gram operator. + + Parameters + ---------- + weight + either ones or density compensation weights + trajectory + k-space trajectory + recon_shape + shape of the reconstructed image + + Returns + ------- + kernel + real valued convolution kernel for the NUFFT gram operator, already in Fourier space + """ + rank = trajectory.shape[-2] + if rank != len(recon_shape): + raise ValueError('Rank of trajectory and image size must match.') + # Instead of doing one adjoint nufft with double the recon size in all dimensions, + # we do two adjoint nuffts per dimensions, saving a lot of memory. + adjnufft_ob = KbNufftAdjoint(im_size=recon_shape, n_shift=[0] * rank).to(trajectory) + + kernel = adjnufft_ob(weight, trajectory) # this will be the top left ... corner block + pad = [] + for s in kernel.shape[: -rank - 1 : -1]: + pad.extend([0, s]) + kernel = torch.nn.functional.pad(kernel, pad) # twice the size in all dimensions + + for flips in list(product([1, -1], repeat=rank)): + if all(flip == 1 for flip in flips): + # top left ... 
block already processed before padding + continue + flipped_trajectory = trajectory * torch.tensor(flips).to(trajectory).unsqueeze(-1) + kernel_part = adjnufft_ob(weight, flipped_trajectory) + slices = [] # which part of the kernel is currently being processed + for dim, flip in zip(range(-rank, 0), flips, strict=True): + if flip > 0: # first half in the dimension + slices.append(slice(0, kernel_part.size(dim))) + else: # second half in the dimension + slices.append(slice(kernel_part.size(dim) + 1, None)) + kernel_part = kernel_part.index_select(dim, torch.arange(kernel_part.size(dim) - 1, 0, -1)) # flip + + kernel[[..., *slices]] = kernel_part + + kernel = symmetrize(kernel, rank) + kernel = torch.fft.hfftn(kernel, dim=list(range(-rank, 0)), norm='backward') + kernel /= kernel.shape[-rank:].numel() + kernel = torch.fft.fftshift(kernel, dim=list(range(-rank, 0))) + return kernel + + +class FourierGramOp(LinearOperator): + """Gram operator for the Fourier operator. + + Implements the adjoint of the forward operator of the Fourier operator, i.e. the gram operator + `F.H@F`. + + Uses a convolution, implemented as multiplication in Fourier space, to calculate the gram operator + for the toeplitz NUFFT operator + + Uses a multiplication with a binary mask in Fourier space to calculate the gram operator for + the Cartesian FFT operator + + This Operator is only used internally and should not be used directly. + Instead, consider using the `gram` property of :class: `mrpro.operators.FourierOp`. + """ + + _kernel: torch.Tensor | None + + def __init__(self, fourier_op: FourierOp) -> None: + """Initialize the gram operator. + + If density compensation weights are provided, the operator + F.H@dcf@F is calculated. 
+ + Parameters + ---------- + fourier_op + the Fourier operator to calculate the gram operator for + + """ + super().__init__() + if fourier_op._nufft_dims and fourier_op._omega is not None: + weight = torch.ones_like(fourier_op._omega[..., :1, :, :, :]) + keep_dims = [-4, *fourier_op._nufft_dims] # -4 is coil + permute = [i for i in range(-weight.ndim, 0) if i not in keep_dims] + keep_dims + unpermute = np.argsort(permute) + weight = weight.permute(*permute) + weight_unflattend_shape = weight.shape + weight = weight.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) + weight = weight + 0j + omega = fourier_op._omega.permute(*permute) + omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) + kernel = gram_nufft_kernel(weight, omega, fourier_op._nufft_im_size) + kernel = kernel.reshape(*weight_unflattend_shape[: -len(keep_dims)], *kernel.shape[-len(keep_dims) :]) + kernel = kernel.permute(*unpermute) + fft = FastFourierOp( + dim=fourier_op._nufft_dims, + encoding_matrix=[2 * s for s in fourier_op._nufft_im_size], + recon_matrix=fourier_op._nufft_im_size, + ) + self.nufft_gram: None | LinearOperator = fft.H * kernel @ fft + else: + self.nufft_gram = None + + if fourier_op._fast_fourier_op is not None and fourier_op._cart_sampling_op is not None: + self.fast_fourier_gram: None | LinearOperator = ( + fourier_op._fast_fourier_op.H @ fourier_op._cart_sampling_op.gram @ fourier_op._fast_fourier_op + ) + else: + self.fast_fourier_gram = None + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the operator to the input tensor. + + Parameters + ---------- + x + input tensor, shape (..., coils, z, y, x) + """ + if self.nufft_gram is not None: + (x,) = self.nufft_gram(x) + + if self.fast_fourier_gram is not None: + (x,) = self.fast_fourier_gram(x) + return (x,) + + def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint operator to the input tensor. 
+ + Parameters + ---------- + x + input tensor, shape (..., coils, k2, k1, k0) + """ + return self.forward(x) + + @property + def H(self) -> Self: # noqa: N802 + """Adjoint operator of the gram operator.""" + return self diff --git a/src/mrpro/operators/Functional.py b/src/mrpro/operators/Functional.py index a55603818..6dbce2b6e 100644 --- a/src/mrpro/operators/Functional.py +++ b/src/mrpro/operators/Functional.py @@ -50,8 +50,8 @@ class ElementaryFunctional(Functional): def __init__( self, - weight: torch.Tensor | complex = 1.0, target: torch.Tensor | None | complex = None, + weight: torch.Tensor | complex = 1.0, dim: int | Sequence[int] | None = None, divide_by_n: bool = False, keepdim: bool = False, @@ -64,10 +64,10 @@ def __init__( Parameters ---------- - weight - weight parameter (see above) target target element - often data tensor (see above) + weight + weight parameter (see above) dim dimension(s) over which functional is reduced. All other dimensions of `weight ( x - target)` will be treated as batch dimensions. diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index 4e23e81b0..029089f5d 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -6,9 +6,10 @@ from abc import abstractmethod from collections.abc import Callable, Sequence from functools import reduce -from typing import Any, cast, no_type_check, overload +from typing import cast, no_type_check import torch +from typing_extensions import Any, Unpack, overload import mrpro.operators from mrpro.operators.Operator import ( @@ -101,7 +102,7 @@ def operator_norm( max_iterations: int = 20, relative_tolerance: float = 1e-4, absolute_tolerance: float = 1e-5, - callback: Callable | None = None, + callback: Callable[[torch.Tensor], None] | None = None, ) -> torch.Tensor: """Power iteration for computing the operator norm of the linear operator. 
@@ -162,7 +163,7 @@ def operator_norm( # operator norm is a strictly positive number. This ensures that the first time the # change between the old and the new estimate of the operator norm is non-zero and # thus prevents the loop from exiting despite a non-correct estimate. - op_norm_old = torch.zeros(*tuple([1 for _ in range(vector.ndim)])) + op_norm_old = torch.zeros(*tuple([1 for _ in range(vector.ndim)]), device=vector.device) dim = tuple(dim) if dim is not None else dim for _ in range(max_iterations): @@ -194,11 +195,13 @@ def operator_norm( def __matmul__(self, other: LinearOperator) -> LinearOperator: ... @overload - def __matmul__(self, other: Operator[*Tin2, tuple[torch.Tensor,]]) -> Operator[*Tin2, tuple[torch.Tensor,]]: ... + def __matmul__( + self, other: Operator[Unpack[Tin2], tuple[torch.Tensor,]] + ) -> Operator[Unpack[Tin2], tuple[torch.Tensor,]]: ... def __matmul__( - self, other: Operator[*Tin2, tuple[torch.Tensor,]] | LinearOperator - ) -> Operator[*Tin2, tuple[torch.Tensor,]] | LinearOperator: + self, other: Operator[Unpack[Tin2], tuple[torch.Tensor,]] | LinearOperator + ) -> Operator[Unpack[Tin2], tuple[torch.Tensor,]] | LinearOperator: """Operator composition. 
Returns lambda x: self(other(x)) @@ -213,7 +216,7 @@ def __matmul__( return LinearOperatorComposition(self, other) elif isinstance(other, Operator): # cast due to https://github.com/python/mypy/issues/16335 - return OperatorComposition(self, cast(Operator[*Tin2, tuple[torch.Tensor,]], other)) + return OperatorComposition(self, cast(Operator[Unpack[Tin2], tuple[torch.Tensor,]], other)) return NotImplemented # type: ignore[unreachable] def __radd__(self, other: torch.Tensor) -> LinearOperator: @@ -291,6 +294,20 @@ def __rmul__(self, other: torch.Tensor | complex) -> LinearOperator: else: return NotImplemented # type: ignore[unreachable] + def __and__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Vertical stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self], [other]] + return mrpro.operators.LinearOperatorMatrix(operators) + + def __or__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Horizontal stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self, other]] + return mrpro.operators.LinearOperatorMatrix(operators) + @property def gram(self) -> LinearOperator: """Gram operator. 
diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py new file mode 100644 index 000000000..ab0673398 --- /dev/null +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -0,0 +1,363 @@ +"""Linear Operator Matrix class.""" + +from __future__ import annotations + +import operator +from collections.abc import Callable, Iterator, Sequence +from functools import reduce +from types import EllipsisType +from typing import cast + +import torch +from typing_extensions import Self, Unpack + +from mrpro.operators.LinearOperator import LinearOperator, LinearOperatorSum +from mrpro.operators.Operator import Operator +from mrpro.operators.ZeroOp import ZeroOp + +_SingleIdxType = int | slice | EllipsisType | Sequence[int] +_IdxType = _SingleIdxType | tuple[_SingleIdxType, _SingleIdxType] + + +class LinearOperatorMatrix(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]]): + r"""Matrix of Linear Operators. + + A matrix of Linear Operators, where each element is a Linear Operator. + + This matrix can be applied to a sequence of tensors, where the number of tensors should match + the number of columns of the matrix. The output will be a sequence of tensors, where the number + of tensors will match the number of rows of the matrix. + The i-th output tensor is calculated as + :math:`\sum_j \text{operators}[i][j](x[j])` where :math:`\text{operators}[i][j]` is the linear operator + in the i-th row and j-th column and :math:`x[j]` is the j-th input tensor. + + The matrix can be indexed and sliced like a regular matrix to get submatrices. + If indexing returns a single element, it is returned as a Linear Operator. + + Basic arithmetic operations are supported with Linear Operators and Tensors. + + """ + + _operators: list[list[LinearOperator]] + + def __init__(self, operators: Sequence[Sequence[LinearOperator]]): + """Initialize Linear Operator Matrix. 
+ + Create a matrix of LinearOperators from a sequence of rows, where each row is a sequence + of LinearOperators that represent the columns of the matrix. + + Parameters + ---------- + operators + A sequence of rows, which are sequences of Linear Operators. + """ + if not all(isinstance(op, LinearOperator) for row in operators for op in row): + raise ValueError('All elements should be LinearOperators.') + if not all(len(row) == len(operators[0]) for row in operators): + raise ValueError('All rows should have the same length.') + super().__init__() + self._operators = cast( # cast because ModuleList is not recognized as a list + list[list[LinearOperator]], torch.nn.ModuleList(torch.nn.ModuleList(row) for row in operators) + ) + self._shape = (len(operators), len(operators[0]) if operators else 0) + + @property + def shape(self) -> tuple[int, int]: + """Shape of the Operator Matrix (rows, columns).""" + return self._shape + + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has columns. + + Returns + ------- + Output tensors. The same number of tensors as the operator has rows. + """ + if len(x) != self.shape[1]: + raise ValueError('Input should be the same number of tensors as the LinearOperatorMatrix has columns.') + return tuple( + reduce(operator.add, (op(xi)[0] for op, xi in zip(row, x, strict=True))) for row in self._operators + ) + + def __getitem__(self, idx: _IdxType) -> Self | LinearOperator: + """Index the Operator Matrix. + + Parameters + ---------- + idx + Index or slice to select rows and columns. + + Returns + ------- + Subset LinearOperatorMatrix or Linear Operator. 
+ """ + idxs: tuple[_SingleIdxType, _SingleIdxType] = idx if isinstance(idx, tuple) else (idx, slice(None)) + if len(idxs) > 2: + raise IndexError('Too many indices for LinearOperatorMatrix') + + def _to_numeric_index(idx: slice | int | Sequence[int] | EllipsisType, length: int) -> Sequence[int]: + """Convert index to a sequence of integers or raise an error.""" + if isinstance(idx, slice): + if (idx.start is not None and (idx.start < -length or idx.start >= length)) or ( + idx.stop is not None and (idx.stop < -length or idx.stop > length) + ): + raise IndexError('Index out of range') + return range(*idx.indices(length)) + if isinstance(idx, int): + if idx < -length or idx >= length: + raise IndexError('Index out of range') + return (idx,) + if idx is Ellipsis: + return range(length) + if isinstance(idx, Sequence): + if min(idx) < -length or max(idx) >= length: + raise IndexError('Index out of range') + return idx + else: + raise IndexError('Invalid index type') + + row_numbers = _to_numeric_index(idxs[0], self._shape[0]) + col_numbers = _to_numeric_index(idxs[1], self._shape[1]) + + sliced_operators = [ + [row[col_number] for col_number in col_numbers] + for row in [self._operators[row_number] for row_number in row_numbers] + ] + + # Return a single operator if only one row and column is selected + if len(row_numbers) == 1 and len(col_numbers) == 1: + return sliced_operators[0][0] + else: + return self.__class__(sliced_operators) + + def __iter__(self) -> Iterator[Sequence[LinearOperator]]: + """Iterate over the rows of the Operator Matrix.""" + return iter(self._operators) + + def __repr__(self): + """Representation of the Operator Matrix.""" + return f'LinearOperatorMatrix(shape={self._shape}, operators={self._operators})' + + # Note: The type ignores are needed because we currently cannot do arithmetic operations with non-linear operators. 
+ def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: ignore[override] + """Addition.""" + operators: list[list[LinearOperator]] = [] + if isinstance(other, LinearOperatorMatrix): + if self.shape != other.shape: + raise ValueError('OperatorMatrix shapes do not match.') + for self_row, other_row in zip(self._operators, other._operators, strict=False): + operators.append([s + o for s, o in zip(self_row, other_row, strict=False)]) + elif isinstance(other, LinearOperator | torch.Tensor): + if not self.shape[0] == self.shape[1]: + raise NotImplementedError('Cannot add a LinearOperator to a non-square OperatorMatrix.') + for i, self_row in enumerate(self._operators): + operators.append([op + other if i == j else op for j, op in enumerate(self_row)]) + else: + return NotImplemented # type: ignore[unreachable] + return self.__class__(operators) + + def __radd__(self, other: Self | LinearOperator | torch.Tensor) -> Self: + """Right addition.""" + return self.__add__(other) + + def __mul__(self, other: torch.Tensor | Sequence[torch.Tensor | complex] | complex) -> Self: + """LinearOperatorMatrix*Tensor multiplication. + + Example: ([A,B]*c)(x) = [A*c, B*c](x) = A(c*x) + B(c*x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[1] + elif len(other) != self.shape[1]: + raise ValueError('Other should have the same length as the operator has columns.') + else: + other_ = other + operators = [] + for row in self._operators: + operators.append([op * o for op, o in zip(row, other_, strict=True)]) + return self.__class__(operators) + + def __rmul__(self, other: torch.Tensor | Sequence[torch.Tensor] | complex) -> Self: + """Tensor*LinearOperatorMatrix multiplication. 
+ + Example: (c*[A,B])(x) = [c*A, c*B](x) = c*A(x) + c*B(x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[0] + elif len(other) != self.shape[0]: + raise ValueError('Other should have the same length as the operator has rows.') + else: + other_ = other + operators = [] + for row, o in zip(self._operators, other_, strict=True): + operators.append([cast(LinearOperator, o * op) for op in row]) + return self.__class__(operators) + + def __matmul__(self, other: LinearOperator | Self) -> Self: # type: ignore[override] + """Composition of operators.""" + if isinstance(other, LinearOperator): + return self._binary_operation(other, '__matmul__') + elif isinstance(other, LinearOperatorMatrix): + if self.shape[1] != other.shape[0]: + raise ValueError('OperatorMatrix shapes do not match.') + new_operators = [] + for row in self._operators: + new_row = [] + for other_col in zip(*other._operators, strict=True): + elements = [s @ o for s, o in zip(row, other_col, strict=True)] + new_row.append(LinearOperatorSum(*elements)) + new_operators.append(new_row) + return self.__class__(new_operators) + return NotImplemented # type: ignore[unreachable] + + @property + def H(self) -> Self: # noqa N802 + """Adjoints of the operators.""" + return self.__class__([[op.H for op in row] for row in zip(*self._operators, strict=True)]) + + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the adjoint of the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has rows. + + Returns + ------- + Output tensors. The same number of tensors as the operator has columns. + """ + return self.H(*x) + + @classmethod + def from_diagonal(cls, *operators: LinearOperator): + """Create a diagonal LinearOperatorMatrix. 
+ + Create a square LinearOperatorMatrix with the given Linear Operators on the diagonal, + resulting in a block-diagonal linear operator. + + Parameters + ---------- + operators + Sequence of Linear Operators to be placed on the diagonal. + """ + operator_matrix: list[list[LinearOperator]] = [ + [op if i == j else ZeroOp(False) for j in range(len(operators))] for i, op in enumerate(operators) + ] + return cls(operator_matrix) + + def operator_norm( + self, + *initial_value: torch.Tensor, + dim: Sequence[int] | None = None, + max_iterations: int = 20, + relative_tolerance: float = 1e-4, + absolute_tolerance: float = 1e-5, + callback: Callable[[torch.Tensor], None] | None = None, + ) -> torch.Tensor: + """Upper bound of operator norm of the Matrix. + + Uses the bounds :math:`||[A, B]^T|||<=sqrt(||A||^2 + ||B||^2)` and :math:`||[A, B]|||<=max(||A||,||B||)` + to estimate the operator norm of the matrix. + First, operator_norm is called on each element of the matrix. + Next, the norm is estimated for each column using the first bound. + Finally, the norm of the full matrix of linear operators is calculated using the second bound. + + Parameters + ---------- + initial_value + Initial value(s) for the power iteration, length should match the number of columns + of the operator matrix. + dim + Dimensions to calculate the operator norm over. Other dimensions are assumed to be + batch dimensions. None means all dimensions. + max_iterations + Maximum number of iterations used in the power iteration. + relative_tolerance + Relative tolerance for convergence. + absolute_tolerance + Absolute tolerance for convergence. + callback + Callback function to be called with the current estimate of the operator norm. + + + Returns + ------- + Estimated operator norm upper bound. 
+ """ + + def _singlenorm(op: LinearOperator, initial_value: torch.Tensor): + return op.operator_norm( + initial_value, + dim=dim, + max_iterations=max_iterations, + relative_tolerance=relative_tolerance, + absolute_tolerance=absolute_tolerance, + callback=callback, + ) + + if len(initial_value) != self.shape[1]: + raise ValueError('Initial value should have the same length as the operator has columns.') + norms = torch.tensor( + [[_singlenorm(op, iv) for op, iv in zip(row, initial_value, strict=True)] for row in self._operators] + ) + norm = norms.square().sum(-2).sqrt().amax(-1).unsqueeze(-1) + return norm + + def __or__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Vertical stacking.""" + if isinstance(other, LinearOperator): + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking : cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[*self._operators[0], other]] + return self.__class__(operators) + else: + if (rows_self := self.shape[0]) != (rows_other := other.shape[0]): + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack matrices with {rows_self} and {rows_other}.' + ) + operators = [[*self_row, *other_row] for self_row, other_row in zip(self, other, strict=True)] + return self.__class__(operators) + + def __ror__(self, other: LinearOperator) -> Self: + """Vertical stacking.""" + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[other, *self._operators[0]]] + return self.__class__(operators) + + def __and__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Horizontal stacking.""" + if isinstance(other, LinearOperator): + if cols := self.shape[1] > 1: + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack LinearOperator and matrix with {cols} columns.' 
+ )
+ operators = [*self._operators, [other]]
+ return self.__class__(operators)
+ else:
+ if (cols_self := self.shape[1]) != (cols_other := other.shape[1]):
+ raise ValueError(
+ 'Shape mismatch in vertical stacking:'
+ f'cannot stack matrices with {cols_self} and {cols_other} columns.'
+ )
+ operators = [*self._operators, *other]
+ return self.__class__(operators)
+
+ def __rand__(self, other: LinearOperator) -> Self:
+ """Vertical stacking."""
+ if (cols := self.shape[1]) > 1:
+ raise ValueError(
+ f'Shape mismatch in vertical stacking: cannot stack LinearOperator and matrix with {cols} columns.'
+ )
+ operators = [[other], *self._operators]
+ return self.__class__(operators)
diff --git a/src/mrpro/operators/MultiIdentityOp.py b/src/mrpro/operators/MultiIdentityOp.py
index b2b6905eb..4bc2f74d4 100644
--- a/src/mrpro/operators/MultiIdentityOp.py
+++ b/src/mrpro/operators/MultiIdentityOp.py
@@ -1,8 +1,7 @@
 """Identity Operator with arbitrary number of inputs."""

-from typing import Self
-
 import torch
+from typing_extensions import Self

 from mrpro.operators.EndomorphOperator import EndomorphOperator, endomorph
diff --git a/src/mrpro/operators/Operator.py b/src/mrpro/operators/Operator.py
index 3582ef976..d5a7ae83a 100644
--- a/src/mrpro/operators/Operator.py
+++ b/src/mrpro/operators/Operator.py
@@ -4,9 +4,10 @@

 from abc import ABC, abstractmethod
 from functools import reduce
-from typing import Generic, TypeVar, TypeVarTuple, cast, overload
+from typing import Generic, TypeAlias, cast

 import torch
+from typing_extensions import TypeVar, TypeVarTuple, Unpack, overload

 import mrpro.operators

@@ -15,26 +16,30 @@
 Tout = TypeVar('Tout', bound=tuple, covariant=True)  # TODO: bind to torch.Tensors


-class Operator(Generic[*Tin, Tout], ABC, torch.nn.Module):
+class Operator(Generic[Unpack[Tin], Tout], ABC, torch.nn.Module):
     """The general Operator class."""

     @abstractmethod
-    def forward(self, *args: *Tin) -> Tout:
+    def forward(self, *args: Unpack[Tin]) -> Tout:
"""Apply forward operator.""" ... - def __call__(self, *args: *Tin) -> Tout: + def __call__(self, *args: Unpack[Tin]) -> Tout: """Apply the forward operator.""" return super().__call__(*args) - def __matmul__(self: Operator[*Tin, Tout], other: Operator[*Tin2, tuple[*Tin]]) -> Operator[*Tin2, Tout]: + def __matmul__( + self: Operator[Unpack[Tin], Tout], other: Operator[Unpack[Tin2], tuple[Unpack[Tin]]] + ) -> Operator[Unpack[Tin2], Tout]: """Operator composition. Returns lambda x: self(other(x)) """ return OperatorComposition(self, other) - def __radd__(self: Operator[*Tin, tuple[*Tin]], other: torch.Tensor) -> Operator[*Tin, tuple[*Tin]]: + def __radd__( + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor + ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: """Operator right addition. Returns lambda x: other*x + self(x) @@ -42,38 +47,40 @@ def __radd__(self: Operator[*Tin, tuple[*Tin]], other: torch.Tensor) -> Operator return self + other @overload - def __add__(self, other: Operator[*Tin, Tout]) -> Operator[*Tin, Tout]: ... + def __add__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: ... @overload - def __add__(self: Operator[*Tin, tuple[*Tin]], other: torch.Tensor) -> Operator[*Tin, tuple[*Tin]]: ... + def __add__( + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor + ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: ... def __add__( - self, other: Operator[*Tin, Tout] | torch.Tensor | mrpro.operators.ZeroOp - ) -> Operator[*Tin, Tout] | Operator[*Tin, tuple[*Tin]]: + self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | mrpro.operators.ZeroOp + ) -> Operator[Unpack[Tin], Tout] | Operator[Unpack[Tin], tuple[Unpack[Tin]]]: """Operator addition. 
Returns lambda x: self(x) + other(x) if other is a operator, lambda x: self(x) + other*x if other is a tensor """ if isinstance(other, torch.Tensor): - s = cast(Operator[*Tin, tuple[*Tin]], self) - o = cast(Operator[*Tin, tuple[*Tin]], mrpro.operators.MultiIdentityOp() * other) + s = cast(Operator[Unpack[Tin], tuple[Unpack[Tin]]], self) + o = cast(Operator[Unpack[Tin], tuple[Unpack[Tin]]], mrpro.operators.MultiIdentityOp() * other) return OperatorSum(s, o) elif isinstance(other, mrpro.operators.ZeroOp): return self elif isinstance(other, Operator): return OperatorSum( - cast(Operator[*Tin, Tout], other), self + cast(Operator[Unpack[Tin], Tout], other), self ) # cast due to https://github.com/python/mypy/issues/16335 return NotImplemented # type: ignore[unreachable] - def __mul__(self, other: torch.Tensor | complex) -> Operator[*Tin, Tout]: + def __mul__(self, other: torch.Tensor | complex) -> Operator[Unpack[Tin], Tout]: """Operator multiplication with tensor. Returns lambda x: self(x*other) """ return OperatorElementwiseProductLeft(self, other) - def __rmul__(self, other: torch.Tensor | complex) -> Operator[*Tin, Tout]: + def __rmul__(self, other: torch.Tensor | complex) -> Operator[Unpack[Tin], Tout]: """Operator multiplication with tensor. Returns lambda x: other*self(x) @@ -81,10 +88,10 @@ def __rmul__(self, other: torch.Tensor | complex) -> Operator[*Tin, Tout]: return OperatorElementwiseProductRight(self, other) -class OperatorComposition(Operator[*Tin2, Tout]): +class OperatorComposition(Operator[Unpack[Tin2], Tout]): """Operator composition.""" - def __init__(self, operator1: Operator[*Tin, Tout], operator2: Operator[*Tin2, tuple[*Tin]]): + def __init__(self, operator1: Operator[Unpack[Tin], Tout], operator2: Operator[Unpack[Tin2], tuple[Unpack[Tin]]]): """Operator composition initialization. 
Returns lambda x: operator1(operator2(x)) @@ -100,28 +107,28 @@ def __init__(self, operator1: Operator[*Tin, Tout], operator2: Operator[*Tin2, t self._operator1 = operator1 self._operator2 = operator2 - def forward(self, *args: *Tin2) -> Tout: + def forward(self, *args: Unpack[Tin2]) -> Tout: """Operator composition.""" return self._operator1(*self._operator2(*args)) -class OperatorSum(Operator[*Tin, Tout]): +class OperatorSum(Operator[Unpack[Tin], Tout]): """Operator addition.""" - _operators: list[Operator[*Tin, Tout]] # actually a torch.nn.ModuleList + _operators: list[Operator[Unpack[Tin], Tout]] # actually a torch.nn.ModuleList - def __init__(self, operator1: Operator[*Tin, Tout], /, *other_operators: Operator[*Tin, Tout]): + def __init__(self, operator1: Operator[Unpack[Tin], Tout], /, *other_operators: Operator[Unpack[Tin], Tout]): """Operator addition initialization.""" super().__init__() - ops: list[Operator[*Tin, Tout]] = [] + ops: list[Operator[Unpack[Tin], Tout]] = [] for op in (operator1, *other_operators): if isinstance(op, OperatorSum): ops.extend(op._operators) else: ops.append(op) - self._operators = cast(list[Operator[*Tin, Tout]], torch.nn.ModuleList(ops)) + self._operators = cast(list[Operator[Unpack[Tin], Tout]], torch.nn.ModuleList(ops)) - def forward(self, *args: *Tin) -> Tout: + def forward(self, *args: Unpack[Tin]) -> Tout: """Operator addition.""" def _add(a: tuple[torch.Tensor, ...], b: tuple[torch.Tensor, ...]) -> Tout: @@ -131,38 +138,41 @@ def _add(a: tuple[torch.Tensor, ...], b: tuple[torch.Tensor, ...]) -> Tout: return result -class OperatorElementwiseProductRight(Operator[*Tin, Tout]): +class OperatorElementwiseProductRight(Operator[Unpack[Tin], Tout]): """Operator elementwise right multiplication with a tensor. 
Performs Tensor*Operator(x) """ - def __init__(self, operator: Operator[*Tin, Tout], scalar: torch.Tensor | complex): + def __init__(self, operator: Operator[Unpack[Tin], Tout], scalar: torch.Tensor | complex): """Operator elementwise right multiplication initialization.""" super().__init__() self._operator = operator self._scalar = scalar - def forward(self, *args: *Tin) -> Tout: + def forward(self, *args: Unpack[Tin]) -> Tout: """Operator elementwise right multiplication.""" out = self._operator(*args) return cast(Tout, tuple(a * self._scalar for a in out)) -class OperatorElementwiseProductLeft(Operator[*Tin, Tout]): +class OperatorElementwiseProductLeft(Operator[Unpack[Tin], Tout]): """Operator elementwise left multiplication with a tensor. Performs Operator(x*Tensor) """ - def __init__(self, operator: Operator[*Tin, Tout], scalar: torch.Tensor | complex): + def __init__(self, operator: Operator[Unpack[Tin], Tout], scalar: torch.Tensor | complex): """Operator elementwise left multiplication initialization.""" super().__init__() self._operator = operator self._scalar = scalar - def forward(self, *args: *Tin) -> Tout: + def forward(self, *args: Unpack[Tin]) -> Tout: """Operator elementwise left multiplication.""" - multiplied = cast(tuple[*Tin], tuple(a * self._scalar for a in args if isinstance(a, torch.Tensor))) + multiplied = cast(tuple[Unpack[Tin]], tuple(a * self._scalar for a in args if isinstance(a, torch.Tensor))) out = self._operator(*multiplied) return cast(Tout, out) + + +OperatorType: TypeAlias = Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]] diff --git a/src/mrpro/operators/PCACompressionOp.py b/src/mrpro/operators/PCACompressionOp.py new file mode 100644 index 000000000..ace625ea8 --- /dev/null +++ b/src/mrpro/operators/PCACompressionOp.py @@ -0,0 +1,85 @@ +"""PCA Compression Operator.""" + +import einops +import torch +from einops import repeat + +from mrpro.operators.LinearOperator import LinearOperator + + +class 
PCACompressionOp(LinearOperator): + """PCA based compression operator.""" + + def __init__( + self, + data: torch.Tensor, + n_components: int, + ): + """Construct a PCA based compression operator. + + The operator carries out an SVD followed by a threshold of the n_components largest values along the last + dimension of a data with shape (*other, joint_dim, compression_dim). A single SVD is carried out for everything + along joint_dim. Other are batch dimensions. + + Consider combining this operator with :class:`mrpro.operators.RearrangeOp` to make sure the data is + in the correct shape before applying. + + Parameters + ---------- + data + Data of shape (*other, joint_dim, compression_dim) to be used to find the principal components. + n_components + Number of principal components to keep along the compression_dim. + """ + super().__init__() + # different compression matrices along the *other dimensions + data = data - data.mean(-1, keepdim=True) + correlation = einops.einsum(data, data.conj(), '... joint comp1, ... joint comp2 -> ... comp1 comp2') + _, _, v = torch.svd(correlation) + # add joint_dim along which the the compression is the same + v = repeat(v, '... comp1 comp2 -> ... joint_dim comp1 comp2', joint_dim=1) + self.register_buffer('_compression_matrix', v[..., :n_components, :].clone()) + + def forward(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the compression to the data. + + Parameters + ---------- + data + data to be compressed of shape (*other, joint_dim, compression_dim) + + Returns + ------- + compressed data of shape (*other, joint_dim, n_components) + """ + try: + result = (self._compression_matrix @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix {tuple(self._compression_matrix.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' 
+ ) from e + return (result,) + + def adjoint(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint compression to the data. + + Parameters + ---------- + data + compressed data of shape (*other, joint_dim, n_components) + + Returns + ------- + expanded data of shape (*other, joint_dim, compression_dim) + """ + try: + result = (self._compression_matrix.mH @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix^H {tuple(self._compression_matrix.mH.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) diff --git a/src/mrpro/operators/ProximableFunctionalSeparableSum.py b/src/mrpro/operators/ProximableFunctionalSeparableSum.py index b4b96c4b5..bf96d09e7 100644 --- a/src/mrpro/operators/ProximableFunctionalSeparableSum.py +++ b/src/mrpro/operators/ProximableFunctionalSeparableSum.py @@ -5,15 +5,16 @@ import operator from collections.abc import Iterator from functools import reduce -from typing import Self, cast +from typing import cast import torch +from typing_extensions import Self, Unpack from mrpro.operators.Functional import ProximableFunctional from mrpro.operators.Operator import Operator -class ProximableFunctionalSeparableSum(Operator[*tuple[torch.Tensor, ...], tuple[torch.Tensor]]): +class ProximableFunctionalSeparableSum(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor]]): r"""Separabke Sum of Proximable Functionals. This is a separable sum of the functionals. 
The forward method returns the sum of the functionals diff --git a/src/mrpro/operators/SignalModel.py b/src/mrpro/operators/SignalModel.py index 9a081a406..66d573b85 100644 --- a/src/mrpro/operators/SignalModel.py +++ b/src/mrpro/operators/SignalModel.py @@ -1,8 +1,7 @@ """Signal Model Operators.""" -from typing import TypeVarTuple - import torch +from typing_extensions import TypeVarTuple, Unpack from mrpro.operators.Operator import Operator @@ -10,28 +9,5 @@ # SignalModel has multiple inputs and one output -class SignalModel(Operator[*Tin, tuple[torch.Tensor,]]): +class SignalModel(Operator[Unpack[Tin], tuple[torch.Tensor,]]): """Signal Model Operator.""" - - @staticmethod - def expand_tensor_dim(parameter: torch.Tensor, n_dim_to_expand: int) -> torch.Tensor: - """Extend the number of dimensions of a parameter tensor. - - This is commonly used in the `model.forward` to ensure the model parameters can be broadcasted to the - quantitative maps. E.g. a simple `InversionRecovery` model is evaluated for six different inversion times `ti`. - The inversion times are commonly the same for each voxel and hence `ti` could be of shape (6,) and the T1 and M0 - map could be of shape (100,100,100). To make sure `ti` can be broadcasted to the maps it needs to be extended to - the shape (6,1,1,1) which then yields a signal of shape (6,100,100,100). - - Parameters - ---------- - parameter - Parameter (e.g with shape (m,n)) - n_dim_to_expand - Number of dimensions to expand. If <= 0 then parameter is not changed. - - Returns - ------- - Parameter with expanded dimensions (e.g. 
(m,n,1,1) for n_dim_to_expand = 2) - """ - return parameter[..., *[None] * (n_dim_to_expand)] if n_dim_to_expand > 0 else parameter diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index c682e9460..0ed0ec568 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -12,8 +12,10 @@ from mrpro.operators.GridSamplingOp import GridSamplingOp from mrpro.operators.IdentityOp import IdentityOp from mrpro.operators.Jacobian import Jacobian +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp +from mrpro.operators.PCACompressionOp import PCACompressionOp from mrpro.operators.PhaseOp import PhaseOp from mrpro.operators.ProximableFunctionalSeparableSum import ProximableFunctionalSeparableSum from mrpro.operators.SensitivityOp import SensitivityOp @@ -24,8 +26,6 @@ from mrpro.operators.ZeroOp import ZeroOp __all__ = [ - "functionals", - "models", "CartesianSamplingOp", "ConstraintsOp", "DensityCompensationOp", @@ -38,17 +38,24 @@ "Functional", "GridSamplingOp", "IdentityOp", - "LinearOperator", "Jacobian", + "LinearOperator", + "LinearOperatorMatrix", "MagnitudeOp", + "MultiIdentityOp", "Operator", + "PCACompressionOp", "PhaseOp", "ProximableFunctional", "ProximableFunctionalSeparableSum", + "ScaledFunctional", + "ScaledProximableFunctional", "SensitivityOp", "SignalModel", "SliceProjectionOp", "WaveletOp", "ZeroOp", "ZeroPadOp", -] + "functionals", + "models" +] \ No newline at end of file diff --git a/src/mrpro/operators/functionals/L1Norm.py b/src/mrpro/operators/functionals/L1Norm.py index 20380a9eb..29f7b753c 100644 --- a/src/mrpro/operators/functionals/L1Norm.py +++ b/src/mrpro/operators/functionals/L1Norm.py @@ -13,6 +13,8 @@ class L1Norm(ElementaryProximableFunctional): where W is a either a scalar or tensor that corresponds to a (block-) diagonal operator that is applied to the 
input. + In most cases, consider setting divide_by_n to true to be independent of input size. + The norm of the vector is computed along the dimensions given at initialization. """ diff --git a/src/mrpro/operators/functionals/L1NormViewAsReal.py b/src/mrpro/operators/functionals/L1NormViewAsReal.py index d8aba9dac..e4227c70b 100644 --- a/src/mrpro/operators/functionals/L1NormViewAsReal.py +++ b/src/mrpro/operators/functionals/L1NormViewAsReal.py @@ -15,6 +15,8 @@ class L1NormViewAsReal(ElementaryProximableFunctional): If the parameter `weight` is real-valued, :math:`W_r` and :math:`W_i` are both set to `weight`. If it is complex-valued, :math:`W_r` and :math:`W_I` are set to the real and imaginary part, respectively. + In most cases, consider setting divide_by_n to true to be independent of input size. + The norm of the vector is computed along the dimensions set at initialization. """ diff --git a/src/mrpro/operators/functionals/L2NormSquared.py b/src/mrpro/operators/functionals/L2NormSquared.py index 275e71bf9..c8d001f97 100644 --- a/src/mrpro/operators/functionals/L2NormSquared.py +++ b/src/mrpro/operators/functionals/L2NormSquared.py @@ -15,6 +15,8 @@ class L2NormSquared(ElementaryProximableFunctional): reconstruction when using a density-compensation function for k-space pre-conditioning, for masking of image data, or for spatially varying regularization weights. + In most cases, consider setting divide_by_n to true to be independent of input size. + Alternatively the functional :class:`mrpro.operators.functionals.MSE` can be used. The norm is computed along the dimensions given at initialization, all other dimensions are considered batch dimensions. 
""" diff --git a/src/mrpro/operators/functionals/MSE.py b/src/mrpro/operators/functionals/MSE.py new file mode 100644 index 000000000..8b67c4ed7 --- /dev/null +++ b/src/mrpro/operators/functionals/MSE.py @@ -0,0 +1,47 @@ +"""MSE-Functional.""" + +from collections.abc import Sequence + +import torch + +from mrpro.operators.functionals.L2NormSquared import L2NormSquared + + +class MSE(L2NormSquared): + r"""Functional class for the mean squared error.""" + + def __init__( + self, + target: torch.Tensor | None | complex = None, + weight: torch.Tensor | complex = 1.0, + dim: int | Sequence[int] | None = None, + divide_by_n: bool = True, + keepdim: bool = False, + ) -> None: + r"""Initialize MSE Functional. + + The MSE functional is given by + :math:`f: C^N -> [0, \infty), x -> 1/N \| W (x-b)\|_2^2`, + where :math:`W` is either a scalar or tensor that corresponds to a (block-) diagonal operator + that is applied to the input. The division by `N` can be disabled by setting `divide_by_n=False`. + For more details also see :class:`mrpro.operators.functionals.L2NormSquared` + + Parameters + ---------- + target + target element - often data tensor (see above) + weight + weight parameter (see above) + dim + dimension(s) over which functional is reduced. + All other dimensions of `weight ( x - target)` will be treated as batch dimensions. + divide_by_n + if true, the result is scaled by the number of elements of the dimensions indexed by `dim` in + the tensor `weight ( x - target)`. If true, the functional is thus calculated as the mean, + else the sum. + keepdim + if true, the dimension(s) of the input indexed by dim are maintained and collapsed to singletons, + else they are removed from the result. 
+ + """ + super().__init__(weight=weight, target=target, dim=dim, divide_by_n=divide_by_n, keepdim=keepdim) diff --git a/src/mrpro/operators/functionals/MSEDataDiscrepancy.py b/src/mrpro/operators/functionals/MSEDataDiscrepancy.py deleted file mode 100644 index df44746ec..000000000 --- a/src/mrpro/operators/functionals/MSEDataDiscrepancy.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Mean squared error (MSE) data-discrepancy function.""" - -import torch -import torch.nn.functional as F # noqa: N812 - -from mrpro.operators.Operator import Operator - - -class MSEDataDiscrepancy(Operator[torch.Tensor, tuple[torch.Tensor]]): - """Mean Squared Error (MSE) loss function. - - This class implements the function :math:`1./N * || . - data ||_2^2`, where :math:`N` equals to the number of - elements of the tensor. - - Note: if one of data or input is complex-valued, we identify the space :math:`C^N` with :math:`R^{2N}` and - multiply the output by 2. By this, we achieve that for example :math:`MSE(1)` = :math:`MSE(1+1j*0)` = 1. - - Parameters - ---------- - data - observed data - """ - - def __init__(self, data: torch.Tensor): - """Initialize the MSE data-discrepancy operator.""" - super().__init__() - - # observed data - self.data = data - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: - """Calculate the MSE of the input. 
- - Parameters - ---------- - x - tensor whose MSE with respect to the data given at initialization should be calculated - - Returns - ------- - Mean Squared Error (MSE) of input and the data - """ - if torch.is_complex(x) or torch.is_complex(self.data): - # F.mse_loss is only implemented for real tensors - # Thus, we cast both to C and then to R^2 - # and undo the division by ten twice the number of elements in mse_loss - x_r2 = torch.view_as_real(x) if torch.is_complex(x) else torch.view_as_real(x + 1j * 0) - data_r2 = ( - torch.view_as_real(self.data) if torch.is_complex(self.data) else torch.view_as_real(self.data + 1j * 0) - ) - mse = F.mse_loss(x_r2, data_r2) * 2.0 - else: # both are real - mse = F.mse_loss(x, self.data) - return (mse,) diff --git a/src/mrpro/operators/functionals/__init__.py b/src/mrpro/operators/functionals/__init__.py index 2e44c1c9e..3fe3455d7 100644 --- a/src/mrpro/operators/functionals/__init__.py +++ b/src/mrpro/operators/functionals/__init__.py @@ -1,6 +1,6 @@ from mrpro.operators.functionals.L1Norm import L1Norm from mrpro.operators.functionals.L1NormViewAsReal import L1NormViewAsReal from mrpro.operators.functionals.L2NormSquared import L2NormSquared -from mrpro.operators.functionals.MSEDataDiscrepancy import MSEDataDiscrepancy +from mrpro.operators.functionals.MSE import MSE from mrpro.operators.functionals.ZeroFunctional import ZeroFunctional -__all__ = ["L1Norm", "L1NormViewAsReal", "L2NormSquared", "MSEDataDiscrepancy", "ZeroFunctional"] +__all__ = ["L1Norm", "L1NormViewAsReal", "L2NormSquared", "MSE", "ZeroFunctional"] diff --git a/src/mrpro/operators/models/InversionRecovery.py b/src/mrpro/operators/models/InversionRecovery.py index cbf02536c..eb691606c 100644 --- a/src/mrpro/operators/models/InversionRecovery.py +++ b/src/mrpro/operators/models/InversionRecovery.py @@ -3,6 +3,7 @@ import torch from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right class 
InversionRecovery(SignalModel[torch.Tensor, torch.Tensor]): @@ -37,6 +38,6 @@ def forward(self, m0: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: ------- signal with shape (time ... other, coils, z, y, x) """ - ti = self.expand_tensor_dim(self.ti, m0.ndim - (self.ti.ndim - 1)) # -1 for time + ti = unsqueeze_right(self.ti, m0.ndim - (self.ti.ndim - 1)) # -1 for time signal = m0 * (1 - 2 * torch.exp(-(ti / t1))) return (signal,) diff --git a/src/mrpro/operators/models/MOLLI.py b/src/mrpro/operators/models/MOLLI.py index b85ffe945..9313e2b4a 100644 --- a/src/mrpro/operators/models/MOLLI.py +++ b/src/mrpro/operators/models/MOLLI.py @@ -3,6 +3,7 @@ import torch from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right class MOLLI(SignalModel[torch.Tensor, torch.Tensor, torch.Tensor]): @@ -51,6 +52,6 @@ def forward(self, a: torch.Tensor, c: torch.Tensor, t1: torch.Tensor) -> tuple[t ------- signal with shape (time ... other, coils, z, y, x) """ - ti = self.expand_tensor_dim(self.ti, a.ndim - (self.ti.ndim - 1)) # -1 for time + ti = unsqueeze_right(self.ti, a.ndim - (self.ti.ndim - 1)) # -1 for time signal = a * (1 - c * torch.exp(ti / t1 * (1 - c))) return (signal,) diff --git a/src/mrpro/operators/models/MonoExponentialDecay.py b/src/mrpro/operators/models/MonoExponentialDecay.py index cf84221dd..a899d84c2 100644 --- a/src/mrpro/operators/models/MonoExponentialDecay.py +++ b/src/mrpro/operators/models/MonoExponentialDecay.py @@ -3,6 +3,7 @@ import torch from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right class MonoExponentialDecay(SignalModel[torch.Tensor, torch.Tensor]): @@ -37,6 +38,6 @@ def forward(self, m0: torch.Tensor, decay_constant: torch.Tensor) -> tuple[torch ------- signal with shape (time ... 
other, coils, z, y, x) """ - decay_time = self.expand_tensor_dim(self.decay_time, m0.ndim - (self.decay_time.ndim - 1)) # -1 for time + decay_time = unsqueeze_right(self.decay_time, m0.ndim - (self.decay_time.ndim - 1)) # -1 for time signal = m0 * torch.exp(-(decay_time / decay_constant)) return (signal,) diff --git a/src/mrpro/operators/models/SaturationRecovery.py b/src/mrpro/operators/models/SaturationRecovery.py index d24c0f496..86ecb0750 100644 --- a/src/mrpro/operators/models/SaturationRecovery.py +++ b/src/mrpro/operators/models/SaturationRecovery.py @@ -3,6 +3,7 @@ import torch from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right class SaturationRecovery(SignalModel[torch.Tensor, torch.Tensor]): @@ -37,6 +38,6 @@ def forward(self, m0: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: ------- signal with shape (time ... other, coils, z, y, x) """ - ti = self.expand_tensor_dim(self.ti, m0.ndim - (self.ti.ndim - 1)) # -1 for time + ti = unsqueeze_right(self.ti, m0.ndim - (self.ti.ndim - 1)) # -1 for time signal = m0 * (1 - torch.exp(-(ti / t1))) return (signal,) diff --git a/src/mrpro/operators/models/TransientSteadyStateWithPreparation.py b/src/mrpro/operators/models/TransientSteadyStateWithPreparation.py index 160a76e37..e08cefa0c 100644 --- a/src/mrpro/operators/models/TransientSteadyStateWithPreparation.py +++ b/src/mrpro/operators/models/TransientSteadyStateWithPreparation.py @@ -3,6 +3,7 @@ import torch from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right class TransientSteadyStateWithPreparation(SignalModel[torch.Tensor, torch.Tensor, torch.Tensor]): @@ -105,13 +106,13 @@ def forward(self, m0: torch.Tensor, t1: torch.Tensor, flip_angle: torch.Tensor) m0_ndim = m0.ndim # -1 for time - sampling_time = self.expand_tensor_dim(self.sampling_time, m0_ndim - (self.sampling_time.ndim - 1)) + sampling_time = unsqueeze_right(self.sampling_time, m0_ndim - 
(self.sampling_time.ndim - 1)) - repetition_time = self.expand_tensor_dim(self.repetition_time, m0_ndim - self.repetition_time.ndim) - m0_scaling_preparation = self.expand_tensor_dim( + repetition_time = unsqueeze_right(self.repetition_time, m0_ndim - self.repetition_time.ndim) + m0_scaling_preparation = unsqueeze_right( self.m0_scaling_preparation, m0_ndim - self.m0_scaling_preparation.ndim ) - delay_after_preparation = self.expand_tensor_dim( + delay_after_preparation = unsqueeze_right( self.delay_after_preparation, m0_ndim - self.delay_after_preparation.ndim ) diff --git a/src/mrpro/operators/models/WASABI.py b/src/mrpro/operators/models/WASABI.py index 66f48847f..207b63c0e 100644 --- a/src/mrpro/operators/models/WASABI.py +++ b/src/mrpro/operators/models/WASABI.py @@ -4,6 +4,7 @@ from torch import nn from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right class WASABI(SignalModel[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): @@ -80,7 +81,7 @@ def forward( ------- signal with shape (offsets ... other, coils, z, y, x) """ - offsets = self.expand_tensor_dim(self.offsets, b0_shift.ndim - (self.offsets.ndim - 1)) # -1 for offset + offsets = unsqueeze_right(self.offsets, b0_shift.ndim - (self.offsets.ndim - 1)) # -1 for offset delta_x = offsets - b0_shift b1 = self.b1_nom * relative_b1 diff --git a/src/mrpro/operators/models/WASABITI.py b/src/mrpro/operators/models/WASABITI.py index e5954ce7b..ee1e4ae31 100644 --- a/src/mrpro/operators/models/WASABITI.py +++ b/src/mrpro/operators/models/WASABITI.py @@ -4,6 +4,7 @@ from torch import nn from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right class WASABITI(SignalModel[torch.Tensor, torch.Tensor, torch.Tensor]): @@ -79,8 +80,8 @@ def forward(self, b0_shift: torch.Tensor, rb1: torch.Tensor, t1: torch.Tensor) - signal with shape (offsets ... 
other, coils, z, y, x) """ delta_ndim = b0_shift.ndim - (self.offsets.ndim - 1) # -1 for offset - offsets = self.expand_tensor_dim(self.offsets, delta_ndim) - trec = self.expand_tensor_dim(self.trec, delta_ndim) + offsets = unsqueeze_right(self.offsets, delta_ndim) + trec = unsqueeze_right(self.trec, delta_ndim) b1 = self.b1_nom * rb1 da = offsets - b0_shift diff --git a/src/mrpro/operators/models/__init__.py b/src/mrpro/operators/models/__init__.py index 0ec3cc517..629a560d3 100644 --- a/src/mrpro/operators/models/__init__.py +++ b/src/mrpro/operators/models/__init__.py @@ -5,4 +5,12 @@ from mrpro.operators.models.WASABITI import WASABITI from mrpro.operators.models.MonoExponentialDecay import MonoExponentialDecay from mrpro.operators.models.TransientSteadyStateWithPreparation import TransientSteadyStateWithPreparation -__all__ = ["SaturationRecovery", "InversionRecovery", "MOLLI", "WASABI", "WASABITI", "MonoExponentialDecay", "TransientSteadyStateWithPreparation"] +__all__ = [ + "InversionRecovery", + "MOLLI", + "MonoExponentialDecay", + "SaturationRecovery", + "TransientSteadyStateWithPreparation", + "WASABI", + "WASABITI" +] \ No newline at end of file diff --git a/src/mrpro/phantoms/__init__.py b/src/mrpro/phantoms/__init__.py index 3c9046364..081d2e465 100644 --- a/src/mrpro/phantoms/__init__.py +++ b/src/mrpro/phantoms/__init__.py @@ -1,3 +1,3 @@ from mrpro.phantoms.EllipsePhantom import EllipsePhantom from mrpro.phantoms.phantom_elements import EllipseParameters -__all__ = ["EllipsePhantom", "EllipseParameters"] +__all__ = ["EllipseParameters", "EllipsePhantom"] \ No newline at end of file diff --git a/src/mrpro/phantoms/coils.py b/src/mrpro/phantoms/coils.py index e936900de..dd9c208fe 100644 --- a/src/mrpro/phantoms/coils.py +++ b/src/mrpro/phantoms/coils.py @@ -1,6 +1,5 @@ """Numerical coil simulations.""" -import numpy as np import torch from einops import repeat @@ -45,16 +44,16 @@ def birdcage_2d( y_co = repeat(y_co, 'y x -> coils y x', coils=1) c = 
repeat(torch.linspace(0, dim[0] - 1, dim[0]), 'coils -> coils y x', y=1, x=1) - coil_center_x = dim[2] * relative_radius * np.cos(c * (2 * torch.pi / dim[0])) - coil_center_y = dim[1] * relative_radius * np.sin(c * (2 * torch.pi / dim[0])) + coil_center_x = dim[2] * relative_radius * torch.cos(c * (2 * torch.pi / dim[0])) + coil_center_y = dim[1] * relative_radius * torch.sin(c * (2 * torch.pi / dim[0])) coil_phase = -c * (2 * torch.pi / dim[0]) rr = torch.sqrt((x_co - coil_center_x) ** 2 + (y_co - coil_center_y) ** 2) phi = torch.arctan2((x_co - coil_center_x), -(y_co - coil_center_y)) + coil_phase - sensitivities = (1 / rr) * np.exp(1j * phi) + sensitivities = (1 / rr) * torch.exp(1j * phi) if normalize_with_rss: - rss = torch.sqrt(torch.sum(torch.abs(sensitivities) ** 2, 0)) + rss = sensitivities.abs().square().sum(0).sqrt() # Normalize only where rss is > 0 sensitivities[:, rss > 0] /= rss[None, rss > 0] diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index be1596ca3..944100d54 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,8 +1,25 @@ import mrpro.utils.slice_profiles import mrpro.utils.typing +import mrpro.utils.unit_conversion +from mrpro.utils.fill_range import fill_range_ from mrpro.utils.smap import smap from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop -from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx -__all__ = ["slice_profiles", "typing", "smap", "remove_repeat", "zero_pad_or_crop", "modify_acq_info", "split_idx",] +from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view +import mrpro.utils.unit_conversion + +__all__ = [ + "broadcast_right", + "fill_range_", + "reduce_view", + "remove_repeat", + "slice_profiles", + "smap", + "split_idx", + "typing", + "unit_conversion", + "unsqueeze_left", + "unsqueeze_right", + "zero_pad_or_crop" +] \ No newline at 
end of file diff --git a/src/mrpro/utils/fill_range.py b/src/mrpro/utils/fill_range.py new file mode 100644 index 000000000..c064c63bd --- /dev/null +++ b/src/mrpro/utils/fill_range.py @@ -0,0 +1,24 @@ +"""Fill tensor in-place along a specified dimension with increasing integers.""" + +import torch + + +def fill_range_(tensor: torch.Tensor, dim: int) -> None: + """ + Fill tensor in-place along a specified dimension with increasing integers. + + Parameters + ---------- + tensor + The tensor to be modified in-place. + + dim + The dimension along which to fill with increasing values. + """ + if not -tensor.ndim <= dim < tensor.ndim: + raise IndexError(f'Dimension {dim} is out of range for tensor with {tensor.ndim} dimensions.') + + dim = dim % tensor.ndim + shape = [s if d == dim else 1 for d, s in enumerate(tensor.shape)] + values = torch.arange(tensor.size(dim), device=tensor.device).reshape(shape) + tensor[:] = values.expand_as(tensor) diff --git a/src/mrpro/utils/modify_acq_info.py b/src/mrpro/utils/modify_acq_info.py deleted file mode 100644 index d535e53c5..000000000 --- a/src/mrpro/utils/modify_acq_info.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Modify AcqInfo.""" - -from __future__ import annotations - -import dataclasses -from collections.abc import Callable -from typing import TYPE_CHECKING - -import torch - -if TYPE_CHECKING: - from mrpro.data.AcqInfo import AcqInfo - - -def modify_acq_info(fun_modify: Callable, acq_info: AcqInfo) -> AcqInfo: - """Go through all fields of AcqInfo object and apply changes. 
- - Parameters - ---------- - fun_modify - Function which takes AcqInfo fields as input and returns modified AcqInfo field - acq_info - AcqInfo object - """ - # Apply function to all fields of acq_info - for field in dataclasses.fields(acq_info): - current = getattr(acq_info, field.name) - if isinstance(current, torch.Tensor): - setattr(acq_info, field.name, fun_modify(current)) - elif dataclasses.is_dataclass(current): - for subfield in dataclasses.fields(current): - subcurrent = getattr(current, subfield.name) - setattr(current, subfield.name, fun_modify(subcurrent)) - - return acq_info diff --git a/src/mrpro/utils/reshape.py b/src/mrpro/utils/reshape.py new file mode 100644 index 000000000..31d495afd --- /dev/null +++ b/src/mrpro/utils/reshape.py @@ -0,0 +1,101 @@ +"""Tensor reshaping utilities.""" + +from collections.abc import Sequence + +import torch + + +def unsqueeze_right(x: torch.Tensor, n: int) -> torch.Tensor: + """Unsqueeze multiple times in the rightmost dimension. + + Example: + tensor with shape (1,2,3) and n=2 would result in tensor with shape (1,2,3,1,1) + + Parameters + ---------- + x + tensor to unsqueeze + n + number of times to unsqueeze + + Returns + ------- + unsqueezed tensor (view) + """ + return x.reshape(*x.shape, *(n * (1,))) + + +def unsqueeze_left(x: torch.Tensor, n: int) -> torch.Tensor: + """Unsqueeze multiple times in the leftmost dimension. + + Example: + tensor with shape (1,2,3) and n=2 would result in tensor with shape (1,1,1,2,3) + + + Parameters + ---------- + x + tensor to unsqueeze + n + number of times to unsqueeze + + Returns + ------- + unsqueezed tensor (view) + """ + return x.reshape(*(n * (1,)), *x.shape) + + +def broadcast_right(*x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Broadcasting on the right. + + Given multiple tensors, apply broadcasting with unsqueezed on the right. + First, tensors are unsqueezed on the right to the same number of dimensions. + Then, torch.broadcasting is used. 
+ + Example: + tensors with shapes (1,2,3), (1,2), (2) + results in tensors with shape (2,2,3) + + Parameters + ---------- + x + tensors to broadcast + + Returns + ------- + broadcasted tensors (views) + """ + max_dim = max(el.ndim for el in x) + unsqueezed = torch.broadcast_tensors(*(unsqueeze_right(el, max_dim - el.ndim) for el in x)) + return unsqueezed + + +def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor: + """Reduce expanded dimensions in a view to singletons. + + Reduce either all or specific dimensions to a singleton if it + points to the same memory address. + This undoes expand. + + Parameters + ---------- + x + input tensor + dim + only reduce expanded dimensions in the specified dimensions. + If None, reduce all expanded dimensions. + """ + if dim is None: + dim_: Sequence[int] = range(x.ndim) + elif isinstance(dim, Sequence): + dim_ = [d % x.ndim for d in dim] + else: + dim_ = [dim % x.ndim] + + stride = x.stride() + newsize = [ + 1 if stride == 0 and d in dim_ else oldsize + for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True)) + ] + return torch.as_strided(x, newsize, stride) diff --git a/src/mrpro/utils/typing.py b/src/mrpro/utils/typing.py index 62adc69b4..f90e18dab 100644 --- a/src/mrpro/utils/typing.py +++ b/src/mrpro/utils/typing.py @@ -1,20 +1,31 @@ """Some type hints that are used in multiple places in the codebase but not part of mrpro's public API.""" -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, TypeAlias + +from typing_extensions import Any if TYPE_CHECKING: from types import EllipsisType - from typing import SupportsIndex, TypeAlias + from typing import TypeAlias import torch + from numpy import ndarray from torch._C import _NestedSequence as NestedSequence + from typing_extensions import SupportsIndex # This matches the torch.Tensor indexer typehint - _IndexerTypeInner: TypeAlias = None | bool | int | slice | EllipsisType | torch.Tensor - 
_SingleIndexerType: TypeAlias = SupportsIndex | _IndexerTypeInner | NestedSequence[_IndexerTypeInner] - IndexerType: TypeAlias = tuple[_SingleIndexerType, ...] | _SingleIndexerType + _TorchIndexerTypeInner: TypeAlias = None | bool | int | slice | EllipsisType | torch.Tensor + _SingleTorchIndexerType: TypeAlias = SupportsIndex | _TorchIndexerTypeInner | NestedSequence[_TorchIndexerTypeInner] + TorchIndexerType: TypeAlias = tuple[_SingleTorchIndexerType, ...] | _SingleTorchIndexerType + + # This matches the numpy.ndarray indexer typehint + _SingleNumpyIndexerType: TypeAlias = ndarray | SupportsIndex | None | slice | EllipsisType + NumpyIndexerType: TypeAlias = tuple[_SingleNumpyIndexerType, ...] | _SingleNumpyIndexerType + + else: - IndexerType: TypeAlias = Any + TorchIndexerType: TypeAlias = Any NestedSequence: TypeAlias = Any + NumpyIndexerType: TypeAlias = Any -__all__ = ['IndexerType', 'NestedSequence'] +__all__ = ['TorchIndexerType', 'NumpyIndexerType', 'NestedSequence'] diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py new file mode 100644 index 000000000..0115bed47 --- /dev/null +++ b/src/mrpro/utils/unit_conversion.py @@ -0,0 +1,94 @@ +"""Conversion between different units.""" + +from typing import TypeVar + +import numpy as np +import torch + +__all__ = [ + 'ms_to_s', + 's_to_ms', + 'mm_to_m', + 'm_to_mm', + 'deg_to_rad', + 'rad_to_deg', + 'lamor_frequency_to_magnetic_field', + 'magnetic_field_to_lamor_frequency', + 'GYROMAGNETIC_RATIO_PROTON', +] + +GYROMAGNETIC_RATIO_PROTON = 42.58 * 1e6 +r"""The gyromagnetic ratio :math:`\frac{\gamma}{2\pi}` of 1H in H20 in Hz/T""" + +# Conversion functions for units +T = TypeVar('T', float, torch.Tensor) + + +def ms_to_s(ms: T) -> T: + """Convert ms to s.""" + return ms / 1000 + + +def s_to_ms(s: T) -> T: + """Convert s to ms.""" + return s * 1000 + + +def mm_to_m(mm: T) -> T: + """Convert mm to m.""" + return mm / 1000 + + +def m_to_mm(m: T) -> T: + """Convert m to mm.""" + return m * 
1000 + + +def deg_to_rad(deg: T) -> T: + """Convert degree to radians.""" + if isinstance(deg, torch.Tensor): + return torch.deg2rad(deg) + return deg / 180.0 * np.pi + + +def rad_to_deg(deg: T) -> T: + """Convert radians to degree.""" + if isinstance(deg, torch.Tensor): + return torch.rad2deg(deg) + return deg * 180.0 / np.pi + + +def lamor_frequency_to_magnetic_field(lamor_frequency: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON) -> T: + """Convert the Lamor frequency [Hz] to the magnetic field strength [T]. + + Parameters + ---------- + lamor_frequency + Lamor frequency [Hz] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Magnetic field strength [T] + """ + return lamor_frequency / gyromagnetic_ratio + + +def magnetic_field_to_lamor_frequency( + magnetic_field_strength: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON +) -> T: + """Convert the magnetic field strength [T] to Lamor frequency [Hz]. + + Parameters + ---------- + magnetic_field_strength + Strength of the magnetic field [T] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Lamor frequency [Hz] + """ + return magnetic_field_strength * gyromagnetic_ratio diff --git a/src/mrpro/utils/vmf.py b/src/mrpro/utils/vmf.py new file mode 100644 index 000000000..7423592e3 --- /dev/null +++ b/src/mrpro/utils/vmf.py @@ -0,0 +1,64 @@ +"""Sampling from von Mises-Fisher distribution.""" + +# based on: https://github.com/jasonlaska/spherecluster/blob/701b0b1909088a56e353b363b2672580d4fe9d93/spherecluster/util.py +# http://stats.stackexchange.com/questions/156729/sampling-from-von-mises-fisher-distribution-in-python +# https://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf +# http://www.stat.pitt.edu/sungkyu/software/randvonMisesFisher3.pdf + +from math import log, sqrt + +import torch + + +def sample_vmf(mu: torch.Tensor, kappa: float, n_samples: int) -> torch.Tensor: + """ 
+ Generate samples from the von Mises-Fisher distribution. + + The von Mises-Fisher distribution is a circular normal distribution on the unit hypersphere. + + Parameters + ---------- + mu + Center of the distribution on the unit hypersphere. Shape: (..., dim) + kappa + Concentration parameter. + For small kappa, the distribution is close to uniform. + For large kappa, the distribution is close to a normal distribution with variance 1/kappa. + n_samples + Number of samples to generate. + + Returns + ------- + Samples from the von Mises-Fisher distribution. Shape: (num_samples, ..., dim) + """ + mu_ = mu.unsqueeze(0) if mu.dim() == 1 else mu + total_samples = n_samples * mu_[..., 0].numel() + mu_ = mu_.expand((n_samples, *mu_.shape)) + dim = mu_.shape[-1] + + b = (dim - 1) / (sqrt(4.0 * kappa**2 + (dim - 1) ** 2) + 2 * kappa) + x = (1.0 - b) / (1.0 + b) + c = kappa * x + (dim - 1) * log(1 - x**2) + + beta_dist = torch.distributions.Beta((dim - 1) / 2.0, (dim - 1) / 2.0) + uniform_dist = torch.distributions.Uniform(0, 1) + normal_dist = torch.distributions.Normal(0, 1) + + ws: list[torch.Tensor] = [] + + while sum(len(w) for w in ws) < total_samples: + # rejection sampling + z = beta_dist.sample(torch.Size((total_samples,))) + w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) + u = uniform_dist.sample(torch.Size((total_samples,))) + accepted = kappa * w + (dim - 1) * torch.log(1.0 - x * w) - c >= torch.log(u) + ws.append(w[accepted]) + weights = torch.cat(ws)[:total_samples].reshape(mu_.shape[:-1]) + + v = normal_dist.sample(mu_.shape) + orthogonal_vectors = v - (mu_ * v).sum(-1, keepdim=True) * mu_ / mu_.norm(dim=-1, keepdim=True) + orthonormal_vectors = orthogonal_vectors / orthogonal_vectors.norm(dim=-1, keepdim=True) + samples = orthonormal_vectors * (1.0 - weights**2).sqrt().unsqueeze(-1) + weights.unsqueeze(-1) * mu_ + if mu.dim() == 1: + samples = samples.squeeze(-2) + return samples diff --git a/src/mrpro/utils/zero_pad_or_crop.py 
b/src/mrpro/utils/zero_pad_or_crop.py index 42adda430..23fb39599 100644 --- a/src/mrpro/utils/zero_pad_or_crop.py +++ b/src/mrpro/utils/zero_pad_or_crop.py @@ -35,7 +35,7 @@ def zero_pad_or_crop( new_shape: Sequence[int] | torch.Size, dim: None | Sequence[int] = None, ) -> torch.Tensor: - """Change shape of data by cropping or zero-padding. + """Change shape of data by center cropping or symmetric zero-padding. Parameters ---------- diff --git a/tests/_RandomGenerator.py b/tests/_RandomGenerator.py index 9248271e9..1d829b55a 100644 --- a/tests/_RandomGenerator.py +++ b/tests/_RandomGenerator.py @@ -27,11 +27,12 @@ class RandomGenerator: """ def __init__(self, seed): + """Initialize with a fixed seed.""" self.generator = torch.Generator().manual_seed(seed) @staticmethod def _clip_bounds(low, high, lowest, highest): - """Clips the bounds (low, high) to the given range (lowest, highest)""" + """Clips the bounds (low, high) to the given range (lowest, highest).""" if low > high: raise ValueError('low should be lower than high') low = max(low, lowest) @@ -54,24 +55,25 @@ def _dtype_bounds(dtype): return (info.min, info.max) def _randint(self, size, low, high, dtype=torch.int64) -> torch.Tensor: - """Generates uniform random integers in the range [low, high) with the - given dtype.""" + """Generate uniform random integers in [low, high) with given dtype.""" low, high = self._clip_bounds(low, high, *self._dtype_bounds(dtype)) return torch.randint(low, high, size, generator=self.generator, dtype=dtype) def _rand(self, size, low, high, dtype=torch.float32) -> torch.Tensor: - """Generates uniform random floats in the range [low, high) with the - given dtype.""" + """Generate uniform random floats in [low, high) with given dtype.""" low, high = self._clip_bounds(low, high, *self._dtype_bounds(dtype)) return (torch.rand(size, generator=self.generator, dtype=dtype) * (high - low)) + low def float32_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 
1.0) -> torch.Tensor: + """Generate float32 tensor of given size in [low, high).""" return self._rand(size, low, high, torch.float32) def float64_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0) -> torch.Tensor: + """Generate float64 tensor of given size in [low, high).""" return self._rand(size, low, high, torch.float64) def complex64_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0) -> torch.Tensor: + """Generate complex64 tensor of given size in [low, high).""" if low < 0: raise ValueError('low/high refer to the amplitude and must be positive') amp = self.float32_tensor(size, low, high) @@ -79,6 +81,7 @@ def complex64_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, h return (amp * torch.exp(1j * phase)).to(dtype=torch.complex64) def complex128_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0) -> torch.Tensor: + """Generate complex128 tensor of given size in [low, high).""" if low < 0: raise ValueError('low/high refer to the amplitude and must be positive') amp = self.float64_tensor(size, low, high) @@ -86,15 +89,19 @@ def complex128_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, return (amp * torch.exp(1j * phase)).to(dtype=torch.complex128) def int8_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 7, high: int = 1 << 7) -> torch.Tensor: + """Generate int8 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.int8) def int16_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 15, high: int = 1 << 15) -> torch.Tensor: + """Generate int16 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.int16) def int32_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 31, high: int = 1 << 31) -> torch.Tensor: + """Generate int32 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.int32) def 
int64_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 63, high: int = 1 << 63) -> torch.Tensor: + """Generate int64 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.int64) # There is no uint32 in pytorch yet @@ -106,106 +113,133 @@ def int64_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 63, hi # return self._randint(size, low, high, dtype=torch.uint64) # noqa: ERA001 def uint8_tensor(self, size: Sequence[int] | int = (1,), low: int = 0, high: int = 1 << 8) -> torch.Tensor: + """Generate uint8 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.uint8) def bool(self) -> bool: + """Generate a random boolean value.""" return self.uint8(0, 1) == 1 def float32(self, low: float = 0.0, high: float = 1.0) -> float: + """Generate a float32 in [low, high).""" return self.float32_tensor((1,), low, high).item() def float64(self, low: float = 0.0, high: float = 1.0) -> float: + """Generate a float64 in [low, high).""" return self.float64_tensor((1,), low, high).item() def complex64(self, low: float = 0, high: float = 1.0) -> complex: + """Generate a complex64 in [low, high).""" return self.complex64_tensor((1,), low, high).item() def complex128(self, low: float = 0, high: float = 1.0) -> complex: + """Generate a complex128 in [low, high).""" return self.complex128_tensor((1,), low, high).item() def uint8(self, low: int = 0, high: int = 1 << 8) -> int: + """Generate a uint8 in [low, high).""" return int(self.uint8_tensor((1,), low, high).item()) def uint16(self, low: int = 0, high: int = 1 << 16) -> int: + """Generate a uint16 in [low, high).""" + if low < 0 or high > 1 << 16: + raise ValueError('Low must be positive and high must be <= 2^16') + # using int32 as it is the smallest that can hold 2^16 (no uint32 in pytorch) return int(self.int32_tensor((1,), low, high).item()) def uint32(self, low: int = 0, high: int = 1 << 32) -> int: - # using int64 to avoid overflow + 
"""Generate a uint32 in [low, high).""" + if low < 0 or high > 1 << 32: + raise ValueError('Low must be positive and high must be <= 2^32') + # using int64 as it is the smallest that can hold 2^32 (no uint64 in pytorch) return int(self.int64_tensor((1,), low, high).item()) def int8(self, low: int = -1 << 7, high: int = 1 << 7 - 1) -> int: + """Generate an int8 in [low, high).""" return int(self.int8_tensor((1,), low, high).item()) def int16(self, low: int = -1 << 15, high: int = 1 << 15 - 1) -> int: + """Generate an int16 in [low, high).""" return int(self.int16_tensor((1,), low, high).item()) def int32(self, low: int = -1 << 31, high: int = 1 << 31 - 1) -> int: + """Generate an int32 in [low, high).""" return int(self.int32_tensor((1,), low, high).item()) def int64(self, low: int = -1 << 63, high: int = 1 << 63 - 1) -> int: + """Generate an int64 in [low, high).""" return int(self.int64_tensor((1,), low, high).item()) def uint64(self, low: int = 0, high: int = 1 << 64) -> int: - # pytorch does not support uint64, so we use int64 instead - # and then convert to uint64 + """Generate a uint64 in [low, high).""" + if low < 0 or high > 1 << 64: + raise ValueError('Low must be positive and high must be <= 2^64') + # no uint64 in pytorch. 
int64 would not be able to produce 2^64, + # so we need to shift the values from [-2^63, 2^63) to [0, 2^64) range_ = high - low - if low < 0: - raise ValueError('Low must be positive') - if range_ > 1 << 64: - raise ValueError('Range too large') new_low = -1 << 63 new_high = new_low + range_ value = self.int64(new_low, new_high) - new_low + low return value def float32_tuple(self, size: int, low: float = 0, high: float = 1) -> tuple[float, ...]: + """Generate a tuple of float32 of given size in [low, high).""" return tuple(self.float32_tensor((size,), low, high)) def float64_tuple(self, size: int, low: float = 0, high: float = 1) -> tuple[float, ...]: + """Generate a tuple of float64 of given size in [low, high).""" return tuple(self.float64_tensor((size,), low, high)) def complex64_tuple(self, size: int, low: float = 0, high: float = 1) -> tuple[complex, ...]: + """Generate a tuple of complex64 of given size in [low, high).""" return tuple(self.complex64_tensor((size,), low, high)) def complex128_tuple(self, size: int, low: float = 0, high: float = 1) -> tuple[complex, ...]: + """Generate a tuple of complex128 of given size in [low, high).""" return tuple(self.complex128_tensor((size,), low, high)) def uint8_tuple(self, size: int, low: int = 0, high: int = 1 << 8) -> tuple[int, ...]: + """Generate a tuple of uint8 of given size in [low, high).""" return tuple(self.uint8_tensor((size,), low, high)) def uint16_tuple(self, size: int, low: int = 0, high: int = 1 << 16) -> tuple[int, ...]: - # no uint16_tensor, so we use uint16 instead + """Generate a tuple of uint16 of given size in [low, high).""" return tuple([self.uint16(low, high) for _ in range(size)]) def uint32_tuple(self, size: int, low: int = 0, high: int = 1 << 32) -> tuple[int, ...]: - # no uint32_tensor, so we use uint32 instead + """Generate a tuple of uint32 of given size in [low, high).""" return tuple([self.uint32(low, high) for _ in range(size)]) def uint64_tuple(self, size: int, low: int = 0, high: 
int = 1 << 64) -> tuple[int, ...]: - # no uint64_tensor, so we use uint64 instead + """Generate a tuple of uint64 of given size in [low, high).""" return tuple([self.uint64(low, high) for _ in range(size)]) def int8_tuple(self, size: int, low: int = -1 << 7, high: int = 1 << 7) -> tuple[int, ...]: + """Generate a tuple of int8 of given size in [low, high).""" return tuple(self.int8_tensor((size,), low, high)) def int16_tuple(self, size: int, low: int = -1 << 15, high: int = 1 << 15) -> tuple[int, ...]: + """Generate a tuple of int16 of given size in [low, high).""" return tuple(self.int16_tensor((size,), low, high)) def int32_tuple(self, size: int, low: int = -1 << 31, high: int = 1 << 31) -> tuple[int, ...]: + """Generate a tuple of int32 of given size in [low, high).""" return tuple(self.int32_tensor((size,), low, high)) def int64_tuple(self, size: int, low: int = -1 << 63, high: int = 1 << 63) -> tuple[int, ...]: + """Generate a tuple of int64 of given size in [low, high).""" return tuple(self.int64_tensor((size,), low, high)) def ascii(self, size: int) -> str: + """Generate a random ASCII string of given size.""" return ''.join([chr(self.uint8(32, 127)) for _ in range(size)]) def rand_like(self, x: torch.Tensor, low=0.0, high=1.0) -> torch.Tensor: - """Generate a tensor with the same shape as x filled with uniform random numbers in [low , high).""" + """Generate tensor like x with uniform random numbers in [low, high).""" return self.rand_tensor(x.shape, x.dtype, low=low, high=high) def rand_tensor(self, shape: Sequence[int], dtype: torch.dtype, low: float, high: float) -> torch.Tensor: - """Generates a tensor with the given shape and dtype filled with uniform random numbers in [low , high).""" + """Generate tensor of given shape and dtype in [low, high).""" if dtype.is_complex: tensor = self.complex64_tensor(shape, low, high).to(dtype=dtype) elif dtype.is_floating_point: @@ -215,3 +249,7 @@ def rand_tensor(self, shape: Sequence[int], dtype: torch.dtype, low: 
float, high else: tensor = self._randint(shape, low, high, dtype) return tensor + + def randperm(self, n, *, dtype=torch.int64) -> torch.Tensor: + """Generate random permutation of integers from 0 to n-1.""" + return torch.randperm(n, generator=self.generator, dtype=dtype) diff --git a/tests/__init__.py b/tests/__init__.py index e05dc1d29..675fd5e26 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,2 @@ from ._RandomGenerator import RandomGenerator +from .helper import relative_image_difference, dotproduct_adjointness_test, operator_isometry_test, linear_operator_unitary_test, autodiff_test diff --git a/tests/algorithms/csm/test_inati.py b/tests/algorithms/csm/test_inati.py index 0e179b2e5..beaa2fa5d 100644 --- a/tests/algorithms/csm/test_inati.py +++ b/tests/algorithms/csm/test_inati.py @@ -3,8 +3,8 @@ import torch from mrpro.algorithms.csm import inati from mrpro.data import SpatialDimension +from tests import relative_image_difference from tests.algorithms.csm.conftest import multi_coil_image -from tests.helper import relative_image_difference def test_inati(ellipse_phantom, random_kheader): diff --git a/tests/algorithms/csm/test_walsh.py b/tests/algorithms/csm/test_walsh.py index 36eee7b60..e08dd6ce3 100644 --- a/tests/algorithms/csm/test_walsh.py +++ b/tests/algorithms/csm/test_walsh.py @@ -3,8 +3,8 @@ import torch from mrpro.algorithms.csm import walsh from mrpro.data import SpatialDimension +from tests import relative_image_difference from tests.algorithms.csm.conftest import multi_coil_image -from tests.helper import relative_image_difference def test_walsh(ellipse_phantom, random_kheader): diff --git a/tests/algorithms/test_cg.py b/tests/algorithms/test_cg.py index 16c16e78b..4abda2a98 100644 --- a/tests/algorithms/test_cg.py +++ b/tests/algorithms/test_cg.py @@ -16,7 +16,8 @@ (1, 32, False), (4, 32, True), (4, 32, False), - ] + ], + ids=['complex_single', 'real_single', 'complex_batch', 'real_batch'], ) def system(request): """Generate data 
for creating a system Hx=b with linear and self-adjoint @@ -145,3 +146,13 @@ def callback(cg_status: CGStatus) -> None: assert True cg(h_operator, right_hand_side, callback=callback) + + +def test_autograd(system): + """Test autograd through cg""" + h_operator, right_hand_side, _ = system + right_hand_side.requires_grad_(True) + with torch.autograd.detect_anomaly(): + result = cg(h_operator, right_hand_side, tolerance=0, max_iterations=5) + result.abs().sum().backward() + assert right_hand_side.grad is not None diff --git a/tests/conftest.py b/tests/conftest.py index e3f943462..30ae9c229 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from xsdata.models.datatype import XmlDate, XmlTime from tests import RandomGenerator +from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData @@ -45,9 +46,9 @@ def generate_random_acquisition_properties(generator: RandomGenerator): 'encoding_space_ref': generator.uint16(), 'sample_time_us': generator.float32(), 'position': generator.float32_tuple(3), - 'read_dir': generator.float32_tuple(3), - 'phase_dir': generator.float32_tuple(3), - 'slice_dir': generator.float32_tuple(3), + 'read_dir': (1, 0, 0), # read, phase and slice have to form rotation + 'phase_dir': (0, 1, 0), + 'slice_dir': (0, 0, 1), 'patient_table_position': generator.float32_tuple(3), 'idx': ismrmrd.EncodingCounters(**idx_properties), 'user_int': generator.uint32_tuple(8), @@ -233,178 +234,193 @@ def create_uniform_traj(nk, k_shape): return k -def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): +def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Create trajectory with random entries.""" random_generator = RandomGenerator(seed=0) k_list = [] - for spacing, nk in zip([sz, sy, sx], [nkz, nky, nkx], strict=True): - if spacing == 'nuf': - k = random_generator.float32_tensor(size=nk) - elif spacing == 'uf': + for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True): + 
if spacing == 'non-uniform': + k = random_generator.float32_tensor(size=nk, low=-1, high=1) * max(nk) + elif spacing == 'uniform': k = create_uniform_traj(nk, k_shape=k_shape) - elif spacing == 'z': + elif spacing == 'zero': k = torch.zeros(nk) k_list.append(k) trajectory = KTrajectory(k_list[0], k_list[1], k_list[2], repeat_detection_tolerance=None) return trajectory +@pytest.fixture(scope='session') +def ismrmrd_cart(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + ) + return ismrmrd_kdata + + COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz', 's0', 's1', 's2'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ - # (0) 2d cart mri with 1 coil, no oversampling - ( - (1, 1, 1, 96, 128), # img shape - (1, 1, 1, 96, 128), # k shape - (1, 1, 1, 128), # kx - (1, 1, 96, 1), # ky - (1, 1, 1, 1), # kz - 'uf', # kx is uniform - 'uf', # ky is uniform - 'z', # zero so no Fourier transform is performed along that dimension - 'uf', # k0 is uniform - 'uf', # k1 is uniform - 'z', # k2 is singleton + ( # (0) 2d Cartesian single coil, no oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 96, 128), # k_shape + (1, 1, 1, 128), # nkx + (1, 1, 96, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (1) 2d cart mri with 1 coil, with oversampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 128, 192), - (1, 1, 1, 192), - (1, 1, 128, 1), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (1) 2d Cartesian single coil, with oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 128, 
192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 128, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (2) 2d non-Cartesian mri with 2 coils - ( - (1, 2, 1, 96, 128), - (1, 2, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 1, 1), - 'nuf', # kx is non-uniform - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (2) 2d non-Cartesian mri with 2 coils + (1, 2, 1, 96, 128), # im_shape + (1, 2, 1, 16, 192), # k_shape + (1, 1, 16, 192), # nkx + (1, 1, 16, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (3) 2d cart mri with irregular sampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'z', - 'z', + ( # (3) 2d Cartesian with irregular sampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (4) 2d single shot spiral - ( - (1, 2, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'z', - 'z', + ( # (4) 2d single shot spiral + (1, 2, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (5) 3d nuFFT mri, 4 coils, 2 other - ( - (2, 4, 16, 32, 64), - (2, 4, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (5) 3d non-uniform, 4 coils, 2 other + 
(2, 4, 16, 32, 64), # im_shape + (2, 4, 16, 32, 64), # k_shape + (2, 16, 32, 64), # nkx + (2, 16, 32, 64), # nky + (2, 16, 32, 64), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (6) 2d nuFFT cine mri with 8 cardiac phases, 5 coils - ( - (8, 5, 1, 64, 64), - (8, 5, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (6) 2d non-uniform cine with 8 cardiac phases, 5 coils + (8, 5, 1, 64, 64), # im_shape + (8, 5, 1, 18, 128), # k_shape + (8, 1, 18, 128), # nkx + (8, 1, 18, 128), # nky + (8, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (7) 2d cart cine mri with 9 cardiac phases, 6 coils - ( - (9, 6, 1, 96, 128), - (9, 6, 1, 128, 192), - (9, 1, 1, 192), - (9, 1, 128, 1), - (9, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (7) 2d cartesian cine with 9 cardiac phases, 6 coils + (9, 6, 1, 96, 128), # im_shape + (9, 6, 1, 128, 192), # k_shape + (9, 1, 1, 192), # nkx + (9, 1, 128, 1), # nky + (9, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and nuFFT directions - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'uf', - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', + ( # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and non-uniform directions + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'uniform', # type_k0 + 'non-uniform', # 
type_k1 + 'non-uniform', # type_k2 ), - # (9) radial phase encoding (RPE) , 8 coils with non-Cartesian sampling along readout - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (9) radial phase encoding (RPE), 8 coils with non-Cartesian sampling along readout + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and nuFFT directions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', - 'uf', + ( # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and non-uniform directions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 1, 18, 64), # nky + (5, 96, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'uniform', # type_k2 ), ], + ids=[ + '2d_cartesian_1_coil_no_oversampling', + '2d_cartesian_1_coil_with_oversampling', + '2d_non_cartesian_mri_2_coils', + '2d_cartesian_irregular_sampling', + '2d_single_shot_spiral', + '3d_nonuniform_4_coils_2_other', + '2d_nnonuniform_cine_mri_8_cardiac_phases_5_coils', + '2d_cartesian_cine_9_cardiac_phases_6_coils', + 'radial_phase_encoding_8_coils_with_oversampling', + 'radial_phase_encoding_8_coils_non_cartesian_sampling', + 'stack_of_stars_5_other_3_coil_with_oversampling', + ], ) diff --git a/tests/data/_IsmrmrdRawTestData.py b/tests/data/_IsmrmrdRawTestData.py index 59c42be15..efeff6ed1 100644 --- a/tests/data/_IsmrmrdRawTestData.py +++ b/tests/data/_IsmrmrdRawTestData.py @@ 
-10,6 +10,8 @@ from mrpro.data import SpatialDimension from mrpro.phantoms import EllipsePhantom +from tests import RandomGenerator + ISMRMRD_TRAJECTORY_TYPE = ( 'cartesian', 'epi', @@ -67,6 +69,7 @@ def __init__( trajectory_type: Literal['cartesian', 'radial'] = 'cartesian', sampling_order: Literal['linear', 'low_high', 'high_low', 'random'] = 'linear', phantom: EllipsePhantom | None = None, + add_bodycoil_acquisitions: bool = False, n_separate_calibration_lines: int = 0, ): if not phantom: @@ -222,23 +225,32 @@ def __init__( acq.phase_dir[1] = 1.0 acq.slice_dir[2] = 1.0 - # Initialize an acquisition counter - counter = 0 + scan_counter = 0 # Write out a few noise scans for _ in range(32): noise = self.noise_level * torch.randn(self.n_coils, n_freq_encoding, dtype=torch.complex64) # here's where we would make the noise correlated - acq.scan_counter = counter + acq.scan_counter = scan_counter acq.clearAllFlags() acq.setFlag(ismrmrd.ACQ_IS_NOISE_MEASUREMENT) acq.data[:] = noise.numpy() dataset.append_acquisition(acq) - counter += 1 # increment the scan counter + scan_counter += 1 + + # Add acquisitions obtained with a 2-element body coil (e.g. 
used for adjustment scans) + if add_bodycoil_acquisitions: + acq.resize(n_freq_encoding, 2, trajectory_dimensions=2) + for _ in range(8): + acq.scan_counter = scan_counter + acq.clearAllFlags() + acq.data[:] = torch.randn(2, n_freq_encoding, dtype=torch.complex64) + dataset.append_acquisition(acq) + scan_counter += 1 + acq.resize(n_freq_encoding, self.n_coils, trajectory_dimensions=2) # Calibration lines if n_separate_calibration_lines > 0: - # we take calibration lines around the k-space center traj_ky_calibration, traj_kx_calibration, kpe_calibration = self._cartesian_trajectory( n_separate_calibration_lines, n_freq_encoding, @@ -253,7 +265,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe_calibration): # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -264,7 +276,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_calibration[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Loop over the repetitions, add noise and write to disk for rep in range(self.repetitions): @@ -275,7 +287,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe[rep]): if not self.flag_invalid_reps or rep == 0 or pe_idx < len(kpe[rep]) // 2: # fewer lines for rep > 0 # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -298,7 +310,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_with_noise[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Clean up dataset.close() @@ -337,7 +349,7 @@ def _cartesian_trajectory( # Different temporal orders of phase encoding points if 
sampling_order == 'random': - perm = torch.randperm(len(kpe)) + perm = RandomGenerator(13).randperm(len(kpe)) kpe = kpe[perm[: len(perm) // acceleration]] elif sampling_order == 'linear': kpe, _ = torch.sort(kpe) diff --git a/tests/data/test_csm_data.py b/tests/data/test_csm_data.py index 0246a2a07..bb759a35a 100644 --- a/tests/data/test_csm_data.py +++ b/tests/data/test_csm_data.py @@ -6,8 +6,8 @@ import torch from mrpro.data import CsmData, SpatialDimension +from tests import relative_image_difference from tests.algorithms.csm.test_walsh import multi_coil_image -from tests.helper import relative_image_difference def test_CsmData_is_frozen_dataclass(random_test_data, random_kheader): diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index 822c63045..fa3e4ebd9 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -2,28 +2,30 @@ import pytest import torch -from einops import rearrange, repeat +from einops import repeat from mrpro.data import KData, KTrajectory, SpatialDimension -from mrpro.data.acq_filters import is_coil_calibration_acquisition +from mrpro.data.acq_filters import has_n_coils, is_coil_calibration_acquisition, is_image_acquisition +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp -from mrpro.utils import modify_acq_info, split_idx +from mrpro.utils import split_idx +from tests import relative_image_difference from tests.conftest import RandomGenerator, generate_random_data from tests.data import IsmrmrdRawTestData -from tests.helper import relative_image_difference from tests.phantoms import EllipsePhantomTestData @pytest.fixture(scope='session') -def ismrmrd_cart(ellipse_phantom, tmp_path_factory): - """Fully sampled cartesian data set.""" +def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set with bodycoil and surface coil 
data.""" ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' ismrmrd_kdata = IsmrmrdRawTestData( filename=ismrmrd_filename, noise_level=0.0, repetitions=3, phantom=ellipse_phantom.phantom, + add_bodycoil_acquisitions=True, ) return ismrmrd_kdata @@ -77,10 +79,11 @@ def consistently_shaped_kdata(request, random_kheader_shape): # Start with header kheader, n_other, n_coils, n_k2, n_k1, n_k0 = random_kheader_shape - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) # Create kdata with consistent shape kdata = generate_random_data(RandomGenerator(request.param['seed']), (n_other, n_coils, n_k2, n_k1, n_k0)) @@ -124,6 +127,34 @@ def test_KData_raise_wrong_trajectory_shape(ismrmrd_cart): _ = KData.from_file(ismrmrd_cart.filename, trajectory) +def test_KData_raise_warning_for_bodycoil(ismrmrd_cart_bodycoil_and_surface_coil): + """Mix of bodycoil and surface coil acquisitions leads to warning.""" + with pytest.raises(UserWarning, match='Acquisitions with different number'): + _ = KData.from_file(ismrmrd_cart_bodycoil_and_surface_coil.filename, DummyTrajectory()) + + +@pytest.mark.filterwarnings('ignore:Acquisitions with different number:UserWarning') +def test_KData_select_bodycoil_via_filter(ismrmrd_cart_bodycoil_and_surface_coil): + """Bodycoil can be selected via a custom acquisition filter.""" + # This is the recommended way of selecting the body coil (i.e. 
2 receiver elements) + kdata = KData.from_file( + ismrmrd_cart_bodycoil_and_surface_coil.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + assert kdata.data.shape[-4] == 2 + + +def test_KData_raise_wrong_coil_number(ismrmrd_cart): + """Wrong number of coils leads to empty acquisitions.""" + with pytest.raises(ValueError, match='No acquisitions meeting the given filter criteria were found'): + _ = KData.from_file( + ismrmrd_cart.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + + def test_KData_from_file_diff_nky_for_rep(ismrmrd_cart_invalid_reps): """Multiple repetitions with different number of phase encoding lines.""" with pytest.warns(UserWarning, match=r'different number'): @@ -162,7 +193,7 @@ def test_KData_kspace(ismrmrd_cart): assert relative_image_difference(reconstructed_img[0, 0, 0, ...], ismrmrd_cart.img_ref) <= 0.05 -@pytest.mark.parametrize(('field', 'value'), [('b0', 11.3), ('tr', torch.tensor([24.3]))]) +@pytest.mark.parametrize(('field', 'value'), [('lamor_frequency_proton', 42.88 * 1e6), ('tr', torch.tensor([24.3]))]) def test_KData_modify_header(ismrmrd_cart, field, value): """Overwrite some parameters in the header.""" parameter_dict = {field: value} @@ -469,3 +500,77 @@ def test_KData_remove_readout_os(monkeypatch, random_kheader): # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high # differences along the edges of the elliptic objects. 
assert relative_image_difference(torch.abs(img_recon), img_tensor[:, 0, ...]) <= 0.05 + + +def test_modify_acq_info(random_kheader_shape): + """Test the modification of the acquisition info.""" + # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] + kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape + + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) + + # Verify shape + assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.idx.k1.shape == (n_other, n_k2, n_k1) + assert kheader.acq_info.orientation.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.position.z.shape == (n_other, n_k2, n_k1, 1) + + +def test_KData_compress_coils(ismrmrd_cart): + """Test coil combination does not alter image content (much).""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata = kdata.compress_coils(n_compressed_coils=4) + ff_op = FastFourierOp(dim=(-1, -2)) + (reconstructed_img,) = ff_op.adjoint(kdata.data) + + # Image content of each coil is the same. Therefore we only compare one coil image but we need to normalize. 
+ reconstructed_img = reconstructed_img[0, 0, 0, ...].abs() + reconstructed_img /= reconstructed_img.max() + ref_img = ismrmrd_cart.img_ref[0, 0, 0, ...].abs() + ref_img /= ref_img.max() + + assert relative_image_difference(reconstructed_img, ref_img) <= 0.1 + + +@pytest.mark.parametrize( + ('batch_dims', 'joint_dims'), + [ + (None, ...), + ((0,), ...), + ((-2, -1), ...), + (None, (-1, -2, -3)), + (None, (0, -1, -2, -3)), + ], + ids=[ + 'single_compression', + 'batching_along_dim0', + 'batching_along_dim-2_and_dim-1', + 'single_compression_for_last_3_dims', + 'single_compression_for_last_3_and_first_dims', + ], +) +def test_KData_compress_coils_diff_batch_joint_dims(consistently_shaped_kdata, batch_dims, joint_dims): + """Test that all of these options work and yield the same shape.""" + n_compressed_coils = 4 + orig_kdata_shape = consistently_shaped_kdata.data.shape + kdata = consistently_shaped_kdata.compress_coils(n_compressed_coils, batch_dims, joint_dims) + assert kdata.data.shape == (*orig_kdata_shape[:-4], n_compressed_coils, *orig_kdata_shape[-3:]) + + +def test_KData_compress_coils_error_both_batch_and_joint(consistently_shaped_kdata): + """Test if error is raised if both batch_dims and joint_dims is defined.""" + with pytest.raises(ValueError, match='Either batch_dims or joint_dims'): + consistently_shaped_kdata.compress_coils(n_compressed_coils=3, batch_dims=(0,), joint_dims=(0,)) + + +def test_KData_compress_coils_error_coil_dim(consistently_shaped_kdata): + """Test if error is raised if coil_dim is in batch_dims or joint_dims.""" + with pytest.raises(ValueError, match='Coil dimension must not'): + consistently_shaped_kdata.compress_coils(n_compressed_coils=3, batch_dims=(-4,)) + + with pytest.raises(ValueError, match='Coil dimension must not'): + consistently_shaped_kdata.compress_coils(n_compressed_coils=3, joint_dims=(-4,)) diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 4d06d73ff..c92e12b2a 100644 --- 
a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -1,11 +1,11 @@ """Tests the MoveDataMixin class.""" from dataclasses import dataclass, field -from typing import Any import pytest import torch from mrpro.data import MoveDataMixin +from typing_extensions import Any class SharedModule(torch.nn.Module): @@ -23,6 +23,7 @@ class A(MoveDataMixin): """Test class A.""" floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0)) + floattensor2: torch.Tensor = field(default_factory=lambda: torch.tensor(-1.0)) complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) @@ -204,3 +205,42 @@ def testchild(attribute, expected_dtype): assert original is not new, 'original and new should not be the same object' assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' + + +def test_movedatamixin_apply_(): + """Tests apply_ method of MoveDataMixin.""" + data = B() + # make one of the parameters shared to test memo behavior + data.child.floattensor2 = data.child.floattensor + original = data.clone() + + def multiply_by_2(obj): + if isinstance(obj, torch.Tensor): + return obj * 2 + return obj + + data.apply_(multiply_by_2) + torch.testing.assert_close(data.floattensor, original.floattensor * 2) + torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2) + assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' + + +def test_movedatamixin_apply(): + """Tests apply method of MoveDataMixin.""" + data = B() + # make one of the parameters shared to test memo behavior + data.child.floattensor2 = data.child.floattensor + original = data.clone() + + def multiply_by_2(obj): + if isinstance(obj, torch.Tensor): + 
return obj * 2 + return obj + + new = data.apply(multiply_by_2) + torch.testing.assert_close(data.floattensor, original.floattensor) + torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2) + torch.testing.assert_close(new.floattensor, original.floattensor * 2) + torch.testing.assert_close(new.child.floattensor2, original.child.floattensor2 * 2) + assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' + assert new is not data, 'new object should be different from the original' diff --git a/tests/utils/test_rotation.py b/tests/data/test_rotation.py similarity index 85% rename from tests/utils/test_rotation.py rename to tests/data/test_rotation.py index 5f9831b64..035a2d6c6 100644 --- a/tests/utils/test_rotation.py +++ b/tests/data/test_rotation.py @@ -535,7 +535,7 @@ def _test_stats(error: torch.Tensor, mean_max: float, rms_max: float) -> None: assert torch.all(rms < rms_max) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -555,7 +555,7 @@ def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-15, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_symmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -576,7 +576,7 @@ def test_as_euler_symmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-16, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): # Since 
we cannot check for angle equality, we check for rotation matrix @@ -598,7 +598,7 @@ def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): torch.testing.assert_close(mat_expected, mat_estimated) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_symmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix @@ -924,12 +924,12 @@ def test_align_vectors_no_noise(): def test_align_vectors_improper_rotation(): """Test for scipy issue #10444""" - x = torch.tensor([[0.89299824, -0.44372674, 0.0752378], [0.60221789, -0.47564102, -0.6411702]]) - y = torch.tensor([[0.02386536, -0.82176463, 0.5693271], [-0.27654929, -0.95191427, -0.1318321]]) + x = torch.tensor([[0.89299824, -0.44372674, 0.0752378], [0.60221789, -0.47564102, -0.6411702]]).double() + y = torch.tensor([[0.02386536, -0.82176463, 0.5693271], [-0.27654929, -0.95191427, -0.1318321]]).double() est, rssd = Rotation.align_vectors(x, y) - torch.testing.assert_close(x, est(y), atol=1e-6, rtol=1e-4) - assert math.isclose(rssd, 0.0, abs_tol=1e-6, rel_tol=1e-4) + torch.testing.assert_close(x, est(y), atol=1e-7, rtol=0) + torch.testing.assert_close(rssd, torch.tensor(0.0, dtype=torch.float64), atol=1e-7, rtol=0) def test_align_vectors_rssd_sensitivity(): @@ -981,26 +981,23 @@ def test_align_vectors_noise(): # Check error bounds using covariance matrix cov *= sigma torch.testing.assert_close(torch.diag(cov), torch.zeros(3), atol=tolerance, rtol=0) - torch.testing.assert_close(torch.sum((noisy_result - est(vectors)) ** 2) ** 0.5, torch.tensor(rssd)) + torch.testing.assert_close(torch.sum((noisy_result - est(vectors)) ** 2) ** 0.5, rssd) def test_align_vectors_invalid_input(): - with pytest.raises(ValueError, match='Expected input `a` to have shape'): + with pytest.raises(ValueError, match='Expected inputs 
to have same shapes'): Rotation.align_vectors([1, 2, 3, 4], [1, 2, 3]) - with pytest.raises(ValueError, match='Expected input `b` to have shape'): - Rotation.align_vectors([1, 2, 3], [1, 2, 3, 4]) - - with pytest.raises(ValueError, match='Expected inputs `a` and `b` ' 'to have same shapes'): + with pytest.raises(ValueError, match='Expected inputs to have same shapes'): Rotation.align_vectors([[1, 2, 3], [4, 5, 6]], [[1, 2, 3]]) - with pytest.raises(ValueError, match='Expected `weights` to be 1 dimensional'): - Rotation.align_vectors([[1, 2, 3]], [[1, 2, 3]], weights=[[1]]) + with pytest.raises(ValueError, match='Expected inputs to have shape'): + Rotation.align_vectors([1, 2, 3, 4], [1, 2, 3, 4]) - with pytest.raises(ValueError, match='Expected `weights` to have number of values'): + with pytest.raises(ValueError, match='Invalid weights: expected shape'): Rotation.align_vectors([[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], weights=[1, 2, 3]) - with pytest.raises(ValueError, match='`weights` may not contain negative values'): + with pytest.raises(ValueError, match='Invalid weights: expected shape'): Rotation.align_vectors([[1, 2, 3]], [[1, 2, 3]], weights=[-1]) with pytest.raises(ValueError, match='Only one infinite weight is allowed'): @@ -1083,10 +1080,10 @@ def test_align_vectors_near_inf(): def test_align_vectors_parallel(): atol = 1e-6 + a = [[1, 0, 0], [0, 1, 0]] b = [[0, 1, 0], [0, 1, 0]] m_expected = torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]).float() - r, _ = Rotation.align_vectors(a, b, weights=[torch.inf, 1]) torch.testing.assert_close(r.as_matrix(), m_expected, atol=atol, rtol=0) @@ -1337,7 +1334,8 @@ def test_mean_invalid_weights(): def test_repr(): """Test string representation""" assert repr(Rotation.identity(None)) == 'Rotation([[0.0, 0.0, 0.0, 1.0]])' - assert repr(Rotation.identity(1)) == '(1,)-Batched Rotation()' + assert repr(Rotation.identity(1)) == '(1,)-batched Rotation()' + assert repr(Rotation.identity(1).reflect()) == '(1,)-batched 
improper Rotation()' def test_quaternion_properties_single(): @@ -1381,3 +1379,209 @@ def test_quaternion_properties_batch(): def test_axis_order_zyx(): """Check that the axis order is set to zyx""" assert AXIS_ORDER == 'zyx' + + +def test_from_to_directions(): + """Test that from_directions and as_directions are inverse operations""" + one = torch.ones(1, 2, 3, 4) + + # must be a rotation + b1 = SpatialDimension(one * (0.8146), one * (0.4707), one * (-0.3388)) + b2 = SpatialDimension(one * (-0.4432), one * (0.8820), one * (0.1599)) + b3 = SpatialDimension(one * (-0.3741), one * (-0.0199), one * (-0.9272)) + + r = Rotation.from_directions(b1, b2, b3) + torch.testing.assert_close(b1.zyx, r.as_directions()[0].zyx, atol=1e-4, rtol=0) + torch.testing.assert_close(b2.zyx, r.as_directions()[1].zyx, atol=1e-4, rtol=0) + torch.testing.assert_close(b3.zyx, r.as_directions()[2].zyx, atol=1e-4, rtol=0) + + +def test_as_directions(): + """Test conversion to basis vectors""" + r = Rotation.random(10, random_state=0) + matrix = r.as_matrix() + directions = r.as_directions() + for col, basis in enumerate(directions): + for row, axis in enumerate(AXIS_ORDER): + expected = matrix[:, row, col] + actual = getattr(basis, axis) + torch.testing.assert_close(actual, expected, atol=1e-4, rtol=0) + + +def test_random_improper(): + """Test improper rotations""" + r = Rotation.random(10, random_state=0, improper=True) + matrix = r.as_matrix() + det = torch.linalg.det(matrix) + torch.testing.assert_close(det, -torch.ones(10)) + + +def test_reflect(): + """Test improper rotations""" + r = Rotation.random(None, random_state=0) + r2 = r.reflect() + r3 = r2.reflect() + det = torch.linalg.det(r2.as_matrix()) + torch.testing.assert_close(det, torch.tensor(-1.0)) + torch.testing.assert_close(r.as_matrix(), r3.as_matrix()) + + +def test_invert_axes(): + """Test inversion of axes""" + r = Rotation.random(None, random_state=0) + r2 = r.invert_axes() + r3 = r2.invert_axes() + det = 
torch.linalg.det(r2.as_matrix()) + torch.testing.assert_close(det, torch.tensor(-1.0)) + torch.testing.assert_close(r.as_matrix(), r3.as_matrix()) + torch.testing.assert_close(r.as_matrix(), -r2.as_matrix()) + + +def test_improper_quat_inversion(): + """Test improper quaternions with inversion""" + r = Rotation.random(10, random_state=0, improper='random') + q, inv = r.as_quat(improper='inversion') + assert torch.equal(r.is_improper, inv) + r2 = Rotation.from_quat(q, inversion=inv) + assert r2.approx_equal(r).all() + + +def test_improper_quat_reflection(): + """Test improper quaternions with reflection""" + r = Rotation.random(10, random_state=0, improper='random') + q, ref = r.as_quat(improper='reflection') + assert torch.equal(r.is_improper, ref) + r2 = Rotation.from_quat(q, reflection=ref) + assert r2.approx_equal(r).all() + + +def test_improper_quat_warn(): + """Test improper quaternions with warning""" + r = Rotation.random(10, random_state=0, improper=True) + with pytest.warns(UserWarning, match='Rotation contains improper'): + _ = r.as_quat(improper='warn') + + +def test_improper_euler_reflection(): + """Test improper euler angles with reflection""" + r = Rotation.random(10, random_state=0, improper=True) + angle, ref = r.as_euler('xyz', improper='reflection') + r2 = Rotation.from_euler('xyz', angle, reflection=ref) + assert r2.approx_equal(r, atol=1e-5).all() # loss of precision in reflection conversion + + +def test_improper_euler_inversion(): + """Test improper euler angles with inversion""" + r = Rotation.random(10, random_state=0, improper=True) + angle, inv = r.as_euler('xyz', improper='inversion') + r2 = Rotation.from_euler('xyz', angle, inversion=inv) + assert r2.approx_equal(r).all() + + +def test_improper_euler_warn(): + """Test improper euler angles with warning""" + r = Rotation.random(10, random_state=0, improper=True) + with pytest.warns(UserWarning, match='Rotation contains improper'): + _ = r.as_euler('xyz', improper='warn') + + +def 
test_improper_as_rotvec_reflection(): + """Test improper as_rotvec with reflection""" + r = Rotation.random(10, random_state=0, improper=True) + expected = r.reflect().as_rotvec() + actual, _ = r.as_rotvec(improper='reflection') + + torch.testing.assert_close(actual, expected, atol=1e-5, rtol=0) + + +def test_improper_from_rotvec_reflection(): + """Test improper from_rotvec with reflection""" + # Test the shortcut in from_rotvec + r = Rotation.random(10, random_state=0, improper=False) + rotvec = r.as_rotvec() + actual = Rotation.from_rotvec(rotvec, reflection=True) + expected = Rotation.from_rotvec(rotvec).reflect() + assert actual.approx_equal(expected).all() + + +def test_improper_rotvec_inversion(): + """Test improper rotvec with inversion""" + r = Rotation.random(10, random_state=0, improper=True) + rotvec, inv = r.as_rotvec(improper='inversion') + r2 = Rotation.from_rotvec(rotvec, inversion=inv) + assert r2.approx_equal(r).all() + + +def test_improper_rotvec_reflection(): + """Test improper rotvec with inversion""" + r = Rotation.random(1, random_state=0, improper=False) + rotvec = r.as_rotvec() + r2 = r.reflect() + r3 = Rotation.from_rotvec(rotvec, reflection=True) + assert r2.approx_equal(r3).all() + + +def test_improper_rotvec_warn(): + """Test improper rotvec with warning""" + r = Rotation.random(10, random_state=0, improper=True) + with pytest.warns(UserWarning, match='Rotation contains improper'): + _ = r.as_rotvec(improper='warn') + + +def test_apply_scipy(): + """Test apply to vector (scipy style apply)""" + r = Rotation.random(10, random_state=0) + v = RandomGenerator(0).float32_tensor(size=(10, 3)) + with pytest.warns(UserWarning, match='Consider using Rotation'): + actual = r.apply(v) + expected = (r.as_matrix() @ v.unsqueeze(-1)).squeeze(-1) + torch.testing.assert_close(expected, actual) + + +def test_apply_torch(): + """Test apply with callable (torch style apply)""" + r = Rotation.random(10, random_state=0) + r.apply(lambda x: x.double()) + 
assert r._quaternions.dtype == torch.float64 + + +def test_random_vmf_uniform(): + """Test random rotations with a uniform distribution""" + mean = torch.tensor([0, 0, 1.0]) + # vmf does not support a seed, as torch.distribution do not support it + prev_rng_state = torch.random.get_rng_state() + torch.manual_seed(0) + r = Rotation.random_vmf(10000, mean, kappa=0, sigma=math.inf) + torch.random.set_rng_state(prev_rng_state) + assert r.shape == (10000,) + assert r.mean().magnitude() < 0.1 + + +def test_random_vmf_peaked(): + """Test random rotations with a peaked distribution""" + mean = torch.tensor([0.0, 1.0, 0.0]) + # vmf does not support a seed, as torch.distribution do not support it + prev_rng_state = torch.random.get_rng_state() + torch.manual_seed(0) + r = Rotation.random_vmf(5000, mean, kappa=50, sigma=20) + torch.random.set_rng_state(prev_rng_state) + assert r.shape == (5000,) + torch.testing.assert_close(torch.linalg.cross(r.mean().as_rotvec(), mean), torch.zeros(3), atol=3e-3, rtol=0) + + +def test_apply_improper(): + """Test apply with improper rotations""" + r = Rotation.random(10, random_state=0, improper=False) + v = RandomGenerator(0).float32_tensor(size=(10, 3)) + actual = r.invert_axes()(v) + expected = (-r.as_matrix() @ v.unsqueeze(-1)).squeeze(-1) + torch.testing.assert_close(expected, actual) + + +def test_reshape(): + r = Rotation.random((1, 2, 3), random_state=0, improper=False) + reshaped = r.reshape(3, 2, 1, 1) + assert reshaped.shape == (3, 2, 1, 1) + assert r.shape == (1, 2, 3) + rereshaped = reshaped.reshape(1, 2, 3) + torch.testing.assert_close(r._quaternions, rereshaped._quaternions) diff --git a/tests/data/test_spatial_dimension.py b/tests/data/test_spatial_dimension.py index b6595b36d..61fd127df 100644 --- a/tests/data/test_spatial_dimension.py +++ b/tests/data/test_spatial_dimension.py @@ -3,6 +3,7 @@ import pytest import torch from mrpro.data import SpatialDimension +from typing_extensions import Any, assert_type from tests import 
RandomGenerator @@ -25,9 +26,9 @@ def test_spatial_dimension_from_xyz_tensor(): """Test creation from an object with x, y, z attributes""" class XYZtensor: - x = 1 * torch.ones(1) - y = 2 * torch.ones(2) - z = 3 * torch.ones(3) + x = 1 * torch.ones(1, 2, 3) + y = 2 * torch.ones(1, 2, 3) + z = 3 * torch.ones(1, 2, 3) spatial_dimension = SpatialDimension.from_xyz(XYZtensor()) assert torch.equal(spatial_dimension.x, XYZtensor.x) @@ -53,6 +54,29 @@ def test_spatial_dimension_from_array(): assert torch.equal(spatial_dimension_xyz.z, spatial_dimension_zyx.z) +def test_from_array_arraylike(): + """Test creation from an ArrayLike list of list of int""" + xyz = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + + spatial_dimension_xyz = SpatialDimension.from_array_xyz(list(zip(*xyz, strict=False))) + assert isinstance(spatial_dimension_xyz.x, torch.Tensor) + assert isinstance(spatial_dimension_xyz.y, torch.Tensor) + assert isinstance(spatial_dimension_xyz.z, torch.Tensor) + assert_type(spatial_dimension_xyz, SpatialDimension[torch.Tensor]) + assert torch.equal(spatial_dimension_xyz.x, torch.tensor(xyz[0])) + assert torch.equal(spatial_dimension_xyz.y, torch.tensor(xyz[1])) + assert torch.equal(spatial_dimension_xyz.z, torch.tensor(xyz[2])) + + spatial_dimension_zyx = SpatialDimension.from_array_zyx(list(zip(*xyz[::-1], strict=False))) + assert isinstance(spatial_dimension_xyz.x, torch.Tensor) + assert isinstance(spatial_dimension_xyz.y, torch.Tensor) + assert isinstance(spatial_dimension_xyz.z, torch.Tensor) + assert_type(spatial_dimension_zyx, SpatialDimension[torch.Tensor]) + assert torch.equal(spatial_dimension_zyx.x, torch.tensor(xyz[0])) + assert torch.equal(spatial_dimension_zyx.y, torch.tensor(xyz[1])) + assert torch.equal(spatial_dimension_zyx.z, torch.tensor(xyz[2])) + + def test_spatial_dimension_from_array_wrongshape(): """Test error message on wrong shape""" tensor_wrongshape = torch.zeros(1, 2, 5) @@ -60,23 +84,63 @@ def test_spatial_dimension_from_array_wrongshape(): _ = 
SpatialDimension.from_array_xyz(tensor_wrongshape) -def test_spatial_dimension_from_array_conversion(): - """Test conversion argument""" +def test_spatial_dimension_broadcasting(): + z = torch.ones(2, 1, 1) + y = torch.ones(1, 2, 1) + x = torch.ones(1, 1, 2) + spatial_dimension = SpatialDimension(z, y, x) + assert spatial_dimension.shape == (2, 2, 2) + + +def test_spatial_dimension_apply_(): + """Test apply_ (in place)""" def conversion(x: torch.Tensor) -> torch.Tensor: assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' return x.swapaxes(0, 1).square() xyz = RandomGenerator(0).float32_tensor((1, 2, 3)) - spatial_dimension_xyz = SpatialDimension.from_array_xyz(xyz.numpy(), conversion=conversion) - assert isinstance(spatial_dimension_xyz.x, torch.Tensor) - assert isinstance(spatial_dimension_xyz.y, torch.Tensor) - assert isinstance(spatial_dimension_xyz.z, torch.Tensor) + spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy()) + spatial_dimension_inplace = spatial_dimension.apply_().apply_(conversion) + + assert spatial_dimension_inplace is spatial_dimension + + assert isinstance(spatial_dimension_inplace.x, torch.Tensor) + assert isinstance(spatial_dimension_inplace.y, torch.Tensor) + assert isinstance(spatial_dimension_inplace.z, torch.Tensor) x, y, z = conversion(xyz).unbind(-1) - assert torch.equal(spatial_dimension_xyz.x, x) - assert torch.equal(spatial_dimension_xyz.y, y) - assert torch.equal(spatial_dimension_xyz.z, z) + assert torch.equal(spatial_dimension_inplace.x, x) + assert torch.equal(spatial_dimension_inplace.y, y) + assert torch.equal(spatial_dimension_inplace.z, z) + + +def test_spatial_dimension_apply(): + """Test apply (out of place)""" + + def conversion(x: torch.Tensor) -> torch.Tensor: + assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' + return x.swapaxes(0, 1).square() + + xyz = RandomGenerator(0).float32_tensor((1, 2, 3)) + spatial_dimension 
= SpatialDimension.from_array_xyz(xyz.numpy()) + spatial_dimension_outofplace = spatial_dimension.apply(conversion) + + assert spatial_dimension_outofplace is not spatial_dimension + + assert isinstance(spatial_dimension_outofplace.x, torch.Tensor) + assert isinstance(spatial_dimension_outofplace.y, torch.Tensor) + assert isinstance(spatial_dimension_outofplace.z, torch.Tensor) + + x, y, z = conversion(xyz).unbind(-1) + assert torch.equal(spatial_dimension_outofplace.x, x) + assert torch.equal(spatial_dimension_outofplace.y, y) + assert torch.equal(spatial_dimension_outofplace.z, z) + + x, y, z = xyz.unbind(-1) # original should be unmodified + assert torch.equal(spatial_dimension.x, x) + assert torch.equal(spatial_dimension.y, y) + assert torch.equal(spatial_dimension.z, z) def test_spatial_dimension_zyx(): @@ -104,6 +168,7 @@ def test_spatial_dimension_cuda_tensor(): assert not spatial_dimension.is_cuda +@pytest.mark.cuda def test_spatial_dimension_cuda_float(): """Test moving to CUDA without tensors -> copy only""" spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) @@ -119,3 +184,472 @@ def test_spatial_dimension_cuda_float(): assert spatial_dimension.device is None assert spatial_dimension_cuda.device is None assert spatial_dimension_cuda is not spatial_dimension + + +def test_spatial_dimension_getitem_tensor(): + """Test accessing elements of SpatialDimension.""" + zyx = RandomGenerator(0).float32_tensor((4, 2, 3)) + spatial_dimension = SpatialDimension.from_array_zyx(zyx) + torch.testing.assert_close(torch.stack(spatial_dimension[:2, ...].zyx, dim=-1), zyx[:2, ...]) + + +def test_spatial_dimension_setitem_tensor(): + """Test setting elements of SpatialDimension[torch.Tensor].""" + zyx = RandomGenerator(0).float32_tensor((4, 2, 3)) + spatial_dimension = SpatialDimension.from_array_zyx(zyx) + spatial_dimension_to_set = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension[2, 1] = spatial_dimension_to_set + assert spatial_dimension[2, 1].zyx == 
spatial_dimension_to_set.zyx + + +def test_spatial_dimension_mul(): + """Test multiplication of SpatialDimension with numeric value.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + value = 3 + spatial_dimension_mul = spatial_dimension * value + assert isinstance(spatial_dimension_mul, SpatialDimension) + assert spatial_dimension_mul.zyx == ( + spatial_dimension.z * value, + spatial_dimension.y * value, + spatial_dimension.x * value, + ) + + +def test_spatial_dimension_rmul(): + """Test right multiplication of SpatialDimension with numeric value.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + value = 3 + spatial_dimension_mul = value * spatial_dimension + assert isinstance(spatial_dimension_mul, SpatialDimension) + assert spatial_dimension_mul.zyx == ( + spatial_dimension.z * value, + spatial_dimension.y * value, + spatial_dimension.x * value, + ) + + +def test_spatial_dimension_mul_spatial_dimension(): + """Test multiplication of SpatialDimension with SpatialDimension.""" + spatial_dimension_1 = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_2 = SpatialDimension(z=4.0, y=5.0, x=6.0) + spatial_dimension_mul = spatial_dimension_1 * spatial_dimension_2 + assert isinstance(spatial_dimension_mul, SpatialDimension) + assert spatial_dimension_mul.zyx == ( + spatial_dimension_1.z * spatial_dimension_2.z, + spatial_dimension_1.y * spatial_dimension_2.y, + spatial_dimension_1.x * spatial_dimension_2.x, + ) + + +def test_spatial_dimension_truediv(): + """Test division of SpatialDimension with numeric value.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + value = 3 + spatial_dimension_div = spatial_dimension / value + assert isinstance(spatial_dimension_div, SpatialDimension) + assert spatial_dimension_div.zyx == ( + spatial_dimension.z / value, + spatial_dimension.y / value, + spatial_dimension.x / value, + ) + + +def test_spatial_dimension_rtruediv(): + """Test right division of SpatialDimension with numeric 
value.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + value = 3 + spatial_dimension_div = value / spatial_dimension + assert isinstance(spatial_dimension_div, SpatialDimension) + assert spatial_dimension_div.zyx == ( + value / spatial_dimension.z, + value / spatial_dimension.y, + value / spatial_dimension.x, + ) + + +def test_spatial_dimension_truediv_spatial_dimension(): + """Test divitions of SpatialDimension with SpatialDimension.""" + spatial_dimension_1 = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_2 = SpatialDimension(z=4.0, y=5.0, x=6.0) + spatial_dimension_div = spatial_dimension_1 / spatial_dimension_2 + assert isinstance(spatial_dimension_div, SpatialDimension) + assert spatial_dimension_div.zyx == ( + spatial_dimension_1.z / spatial_dimension_2.z, + spatial_dimension_1.y / spatial_dimension_2.y, + spatial_dimension_1.x / spatial_dimension_2.x, + ) + + +def test_spatial_dimension_add(): + """Test addition of SpatialDimension with numeric value.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + value = 3 + spatial_dimension_add = spatial_dimension + value + assert isinstance(spatial_dimension_add, SpatialDimension) + assert spatial_dimension_add.zyx == ( + spatial_dimension.z + value, + spatial_dimension.y + value, + spatial_dimension.x + value, + ) + + +def test_spatial_dimension_radd(): + """Test right addition of SpatialDimension with numeric value.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + value = 3 + spatial_dimension_add = value + spatial_dimension + assert isinstance(spatial_dimension_add, SpatialDimension) + assert spatial_dimension_add.zyx == ( + spatial_dimension.z + value, + spatial_dimension.y + value, + spatial_dimension.x + value, + ) + + +def test_spatial_dimension_add_spatial_dimension(): + """Test addition of SpatialDimension with SpatialDimension.""" + spatial_dimension_1 = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_2 = SpatialDimension(z=4.0, y=5.0, 
x=6.0) + spatial_dimension_add = spatial_dimension_1 + spatial_dimension_2 + assert isinstance(spatial_dimension_add, SpatialDimension) + assert spatial_dimension_add.zyx == ( + spatial_dimension_1.z + spatial_dimension_2.z, + spatial_dimension_1.y + spatial_dimension_2.y, + spatial_dimension_1.x + spatial_dimension_2.x, + ) + + +def test_spatial_dimension_sub(): + """Test subtraction of SpatialDimension with numeric value.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + value = 3 + spatial_dimension_sub = spatial_dimension - value + assert isinstance(spatial_dimension_sub, SpatialDimension) + assert spatial_dimension_sub.zyx == ( + spatial_dimension.z - value, + spatial_dimension.y - value, + spatial_dimension.x - value, + ) + + +def test_spatial_dimension_rsub(): + """Test right subtraction of SpatialDimension with numeric value.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + value = 3 + spatial_dimension_sub = value - spatial_dimension + assert isinstance(spatial_dimension_sub, SpatialDimension) + assert spatial_dimension_sub.zyx == ( + value - spatial_dimension.z, + value - spatial_dimension.y, + value - spatial_dimension.x, + ) + + +def test_spatial_dimension_sub_spatial_dimension(): + """Test subtraction of SpatialDimension with SpatialDimension.""" + spatial_dimension_1 = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_2 = SpatialDimension(z=4.0, y=5.0, x=6.0) + spatial_dimension_sub = spatial_dimension_1 - spatial_dimension_2 + assert isinstance(spatial_dimension_sub, SpatialDimension) + assert spatial_dimension_sub.zyx == ( + spatial_dimension_1.z - spatial_dimension_2.z, + spatial_dimension_1.y - spatial_dimension_2.y, + spatial_dimension_1.x - spatial_dimension_2.x, + ) + + +def test_spatial_dimension_eq_float(): + """Test equality of SpatialDimension.""" + eq = SpatialDimension(z=1.0, y=2.0, x=3.0) == SpatialDimension(z=1, y=2, x=3) + assert_type(eq, bool) + assert eq + neq = SpatialDimension(z=1.0, y=2.0, 
x=3.0) == SpatialDimension(z=1.0, y=2.0, x=4.0) + assert_type(neq, bool) + assert not neq + + +def test_spatial_dimension_eq_tensor(): + """Test equality of SpatialDimension with tensors.""" + spatial_dimension_1 = SpatialDimension(z=torch.ones(2), y=torch.ones(2), x=torch.ones(2)) + spatial_dimension_2 = SpatialDimension(z=torch.ones(2), y=torch.ones(2), x=torch.arange(2)) + comp: Any = spatial_dimension_1 == spatial_dimension_2 + assert torch.equal(comp, torch.tensor([False, True])) + + +def test_spatial_dimension_comp_scalar(): + """Test equality of SpatialDimension.""" + assert SpatialDimension(1, 2, 3) > SpatialDimension(0, 0, 0) + assert SpatialDimension(1, 2, 3) >= SpatialDimension(1, 0, 0) + assert not SpatialDimension(1, 2, 3) > SpatialDimension(1, 0, 0) + + assert SpatialDimension(1, 2, 3) < SpatialDimension(10, 10, 10) + assert SpatialDimension(1, 2, 3) <= SpatialDimension(10, 10, 3) + assert not SpatialDimension(1, 2, 3) < SpatialDimension(10, 10, 3) + + +def test_spatial_dimension_comp_tensor(): + """Test equality of SpatialDimension.""" + t = torch.ones(2) + assert (SpatialDimension(1 * t, 2 * t, 3 * t) > SpatialDimension(0 * t, 0 * t, 0 * t)).all() + assert (SpatialDimension(1 * t, 2 * t, 3 * t) >= SpatialDimension(1 * t, 0 * t, 0 * t)).all() + assert not (SpatialDimension(1 * t, 2 * t, 3 * t) > SpatialDimension(1 * t, 0 * t, 0 * t)).any() + + assert (SpatialDimension(1 * t, 2 * t, 3 * t) < SpatialDimension(10 * t, 10 * t, 10 * t)).all() + assert (SpatialDimension(1 * t, 2 * t, 3 * t) <= SpatialDimension(10 * t, 10 * t, 3 * t)).all() + assert not (SpatialDimension(1 * t, 2 * t, 3 * t) < SpatialDimension(10 * t, 10 * t, 3 * t)).any() + + +def test_spatial_dimension_neg(): + """Test negation of SpatialDimension.""" + spatial_dimension = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_neg = -spatial_dimension + assert isinstance(spatial_dimension_neg, SpatialDimension) + assert spatial_dimension_neg.zyx == (-spatial_dimension.z, 
-spatial_dimension.y, -spatial_dimension.x) + + +def mypy_test_spatial_dimension_typing_add(): + """Test typing of SpatialDimension operations (mypy) + + This test checks that the typing of the operations is correct. + It will be used by pytest, but mypy will complain if any of the + types are wrong. + """ + spatial_dimension_float = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_int = SpatialDimension(z=1, y=2, x=3) + + spatial_dimension_tensor = SpatialDimension(z=torch.ones(1), y=torch.ones(1), x=torch.ones(1)) + scalar_float = 1.0 + scalar_int = 1 + scalar_tensor = torch.ones(1) + + # int + assert_type(spatial_dimension_int + spatial_dimension_int, SpatialDimension[int]) + assert_type(spatial_dimension_int + scalar_int, SpatialDimension[int]) + assert_type(scalar_int + spatial_dimension_int, SpatialDimension[int]) + + # tensor + assert_type(spatial_dimension_tensor + spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor + scalar_tensor, SpatialDimension[torch.Tensor]) + # assert_type(scalar_tensor + spatial_dimension_tensor, SpatialDimension[torch.Tensor]) # FIXME torch typing issue # noqa + + # float + assert_type(spatial_dimension_float + spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float + scalar_float, SpatialDimension[float]) + assert_type(scalar_float + spatial_dimension_float, SpatialDimension[float]) + + # int gets promoted to float + assert_type(spatial_dimension_int + spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float + spatial_dimension_int, SpatialDimension[float]) + assert_type(spatial_dimension_int + scalar_float, SpatialDimension[float]) + assert_type(scalar_float + spatial_dimension_int, SpatialDimension[float]) + + # int or float gets promoted to tensor + assert_type(spatial_dimension_int + spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor + spatial_dimension_int, 
SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_float + spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor + spatial_dimension_float, SpatialDimension[torch.Tensor]) + + assert_type(scalar_int + spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor + scalar_int, SpatialDimension[torch.Tensor]) + assert_type(scalar_float + spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor + scalar_float, SpatialDimension[torch.Tensor]) + + +def mypy_test_spatial_dimension_typing_sub(): + """Test typing of SpatialDimension operations (mypy) + + This test checks that the typing of the operations is correct. + It will be used by pytest, but mypy will complain if any of the + types are wrong. + """ + spatial_dimension_float = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_int = SpatialDimension(z=1, y=2, x=3) + + spatial_dimension_tensor = SpatialDimension(z=torch.ones(1), y=torch.ones(1), x=torch.ones(1)) + scalar_float = 1.0 + scalar_int = 1 + scalar_tensor = torch.ones(1) + + # int + assert_type(spatial_dimension_int - spatial_dimension_int, SpatialDimension[int]) + assert_type(spatial_dimension_int - scalar_int, SpatialDimension[int]) + assert_type(scalar_int - spatial_dimension_int, SpatialDimension[int]) + + # tensor + assert_type(spatial_dimension_tensor - spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor - scalar_tensor, SpatialDimension[torch.Tensor]) + # assert_type(scalar_tensor - spatial_dimension_tensor, SpatialDimension[torch.Tensor]) # FIXME torch typing issue # noqa + + # float + assert_type(spatial_dimension_float - spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float - scalar_float, SpatialDimension[float]) + assert_type(scalar_float - spatial_dimension_float, SpatialDimension[float]) + + # int gets promoted to float + 
assert_type(spatial_dimension_int - spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float - spatial_dimension_int, SpatialDimension[float]) + assert_type(spatial_dimension_int - scalar_float, SpatialDimension[float]) + assert_type(scalar_float - spatial_dimension_int, SpatialDimension[float]) + + # int or float gets promoted to tensor + assert_type(spatial_dimension_int - spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor - spatial_dimension_int, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_float - spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor - spatial_dimension_float, SpatialDimension[torch.Tensor]) + + assert_type(scalar_int - spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor - scalar_int, SpatialDimension[torch.Tensor]) + assert_type(scalar_float - spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor - scalar_float, SpatialDimension[torch.Tensor]) + + +def mypy_test_spatial_dimension_typing_floordiv(): + """Test typing of SpatialDimension operations (mypy) + + This test checks that the typing of the operations is correct. + It will be used by pytest, but mypy will complain if any of the + types are wrong. 
+ """ + spatial_dimension_float = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_int = SpatialDimension(z=1, y=2, x=3) + + spatial_dimension_tensor = SpatialDimension(z=torch.ones(1), y=torch.ones(1), x=torch.ones(1)) + scalar_float = 1.0 + scalar_int = 1 + scalar_tensor = torch.ones(1) + + # int + assert_type(spatial_dimension_int // spatial_dimension_int, SpatialDimension[int]) + assert_type(spatial_dimension_int // scalar_int, SpatialDimension[int]) + assert_type(scalar_int // spatial_dimension_int, SpatialDimension[int]) + + # tensor + assert_type(spatial_dimension_tensor // spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor // scalar_tensor, SpatialDimension[torch.Tensor]) + # assert_type(scalar_tensor // spatial_dimension_tensor, SpatialDimension[torch.Tensor]) # FIXME torch typing issue # noqa + + # float + assert_type(spatial_dimension_float // spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float // scalar_float, SpatialDimension[float]) + assert_type(scalar_float // spatial_dimension_float, SpatialDimension[float]) + + # int gets promoted to float + assert_type(spatial_dimension_int // spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float // spatial_dimension_int, SpatialDimension[float]) + assert_type(spatial_dimension_int // scalar_float, SpatialDimension[float]) + assert_type(scalar_float // spatial_dimension_int, SpatialDimension[float]) + + # int or float gets promoted to tensor + assert_type(spatial_dimension_int // spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor // spatial_dimension_int, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_float // spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor // spatial_dimension_float, SpatialDimension[torch.Tensor]) + + assert_type(scalar_int // spatial_dimension_tensor, 
SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor // scalar_int, SpatialDimension[torch.Tensor]) + assert_type(scalar_float // spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor // scalar_float, SpatialDimension[torch.Tensor]) + + +def mypy_test_spatial_dimension_typing_mul(): + """Test typing of SpatialDimension operations (mypy) + + This test checks that the typing of the operations is correct. + It will be used by pytest, but mypy will complain if any of the + types are wrong. + """ + spatial_dimension_float = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_int = SpatialDimension(z=1, y=2, x=3) + + spatial_dimension_tensor = SpatialDimension(z=torch.ones(1), y=torch.ones(1), x=torch.ones(1)) + scalar_float = 1.0 + scalar_int = 1 + scalar_tensor = torch.ones(1) + + # int + assert_type(spatial_dimension_int * spatial_dimension_int, SpatialDimension[int]) + assert_type(spatial_dimension_int * scalar_int, SpatialDimension[int]) + assert_type(scalar_int * spatial_dimension_int, SpatialDimension[int]) + + # tensor + assert_type(spatial_dimension_tensor * spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor * scalar_tensor, SpatialDimension[torch.Tensor]) + # assert_type(scalar_tensor * spatial_dimension_tensor, SpatialDimension[torch.Tensor]) # FIXME torch typing issue # noqa + + # float + assert_type(spatial_dimension_float * spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float * scalar_float, SpatialDimension[float]) + assert_type(scalar_float * spatial_dimension_float, SpatialDimension[float]) + + # int gets promoted to float + assert_type(spatial_dimension_int * spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float * spatial_dimension_int, SpatialDimension[float]) + assert_type(spatial_dimension_int * scalar_float, SpatialDimension[float]) + assert_type(scalar_float * 
spatial_dimension_int, SpatialDimension[float]) + + # int or float gets promoted to tensor + assert_type(spatial_dimension_int * spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor * spatial_dimension_int, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_float * spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor * spatial_dimension_float, SpatialDimension[torch.Tensor]) + + assert_type(scalar_int * spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor * scalar_int, SpatialDimension[torch.Tensor]) + assert_type(scalar_float * spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor * scalar_float, SpatialDimension[torch.Tensor]) + + +def mypy_test_spatial_dimension_typing_truediv(): + """Test typing of SpatialDimension operations (mypy) + + This test checks that the typing of the operations is correct. + It will be used by pytest, but mypy will complain if any of the + types are wrong. 
+ """ + spatial_dimension_float = SpatialDimension(z=1.0, y=2.0, x=3.0) + spatial_dimension_int = SpatialDimension(z=1, y=2, x=3) + + spatial_dimension_tensor = SpatialDimension(z=torch.ones(1), y=torch.ones(1), x=torch.ones(1)) + scalar_float = 1.0 + scalar_int = 1 + scalar_tensor = torch.ones(1) + + # tensor + assert_type(spatial_dimension_tensor / spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor / scalar_tensor, SpatialDimension[torch.Tensor]) + # assert_type(scalar_tensor / spatial_dimension_tensor, SpatialDimension[torch.Tensor]) # FIXME torch typing issue # noqa + + # float + assert_type(spatial_dimension_float / spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float / scalar_float, SpatialDimension[float]) + assert_type(scalar_float / spatial_dimension_float, SpatialDimension[float]) + + # int gets promoted to float + assert_type(spatial_dimension_int / spatial_dimension_int, SpatialDimension[float]) + assert_type(spatial_dimension_int / scalar_int, SpatialDimension[float]) + assert_type(scalar_int / spatial_dimension_int, SpatialDimension[float]) + assert_type(spatial_dimension_int / spatial_dimension_float, SpatialDimension[float]) + assert_type(spatial_dimension_float / spatial_dimension_int, SpatialDimension[float]) + assert_type(spatial_dimension_int / scalar_float, SpatialDimension[float]) + assert_type(scalar_float / spatial_dimension_int, SpatialDimension[float]) + + # int or float gets promoted to tensor + assert_type(spatial_dimension_int / spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor / spatial_dimension_int, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_float / spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor / spatial_dimension_float, SpatialDimension[torch.Tensor]) + + assert_type(scalar_int / spatial_dimension_tensor, SpatialDimension[torch.Tensor]) 
+ assert_type(spatial_dimension_tensor / scalar_int, SpatialDimension[torch.Tensor]) + assert_type(scalar_float / spatial_dimension_tensor, SpatialDimension[torch.Tensor]) + assert_type(spatial_dimension_tensor / scalar_float, SpatialDimension[torch.Tensor]) + + +def test_spatial_dimension_masked_inplace_add(): + """Test inplace add and masking.""" + spatial_dimension = SpatialDimension(z=torch.arange(3), y=torch.arange(3), x=torch.arange(3)) + mask = spatial_dimension < SpatialDimension(z=torch.tensor(1), y=torch.tensor(1), x=torch.tensor(1)) + spatial_dimension[mask] += SpatialDimension(z=1, y=2, x=3) + assert torch.equal(spatial_dimension.z, torch.tensor([1, 1, 2])) + assert torch.equal(spatial_dimension.y, torch.tensor([2, 1, 2])) + assert torch.equal(spatial_dimension.x, torch.tensor([3, 1, 2])) diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index 1baf4340b..1061a93be 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -147,16 +147,16 @@ def test_trajectory_cpu(cartesian_grid): @COMMON_MR_TRAJECTORIES -def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the kz, ky and kz dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), 
(type_kz, type_ky, type_kx), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_kzyx[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) @@ -171,16 +171,16 @@ def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, @COMMON_MR_TRAJECTORIES -def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the k2, k1 and k0 dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_k210[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) diff --git a/tests/helper.py b/tests/helper.py index 794f5685a..7e11826a7 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -1,7 +1,8 @@ """Helper/Utilities for test functions.""" import torch -from mrpro.operators import Operator +from mrpro.operators import LinearOperator, Operator 
+from typing_extensions import TypeVarTuple, Unpack def relative_image_difference(img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: @@ -26,9 +27,13 @@ def relative_image_difference(img1: torch.Tensor, img2: torch.Tensor) -> torch.T def dotproduct_adjointness_test( - operator: Operator, u: torch.Tensor, v: torch.Tensor, relative_tolerance: float = 1e-3, absolute_tolerance=1e-5 + operator: LinearOperator, + u: torch.Tensor, + v: torch.Tensor, + relative_tolerance: float = 1e-3, + absolute_tolerance=1e-5, ): - """Test the adjointness of operator and operator.H. + """Test the adjointness of linear operator and operator.H. Test if == @@ -42,7 +47,7 @@ def dotproduct_adjointness_test( Parameters ---------- operator - operator + linear operator u element of the domain of the operator v @@ -74,9 +79,12 @@ def dotproduct_adjointness_test( def operator_isometry_test( - operator: Operator, u: torch.Tensor, relative_tolerance: float = 1e-3, absolute_tolerance=1e-5 + operator: Operator[torch.Tensor, tuple[torch.Tensor]], + u: torch.Tensor, + relative_tolerance: float = 1e-3, + absolute_tolerance=1e-5, ): - """Test the isometry of an operator. + """Test the isometry of a operator. Test if ||Operator(u)|| == ||u|| @@ -103,10 +111,10 @@ def operator_isometry_test( ) -def operator_unitary_test( - operator: Operator, u: torch.Tensor, relative_tolerance: float = 1e-3, absolute_tolerance=1e-5 +def linear_operator_unitary_test( + operator: LinearOperator, u: torch.Tensor, relative_tolerance: float = 1e-3, absolute_tolerance=1e-5 ): - """Test if an operator is unitary. + """Test if a linear operator is unitary. 
Test if Operator.adjoint(Operator(u)) == u @@ -115,7 +123,7 @@ def operator_unitary_test( Parameters ---------- operator - operator + linear operator u element of the domain of the operator relative_tolerance @@ -129,3 +137,36 @@ def operator_unitary_test( if the adjointness property does not hold """ torch.testing.assert_close(u, operator.adjoint(operator(u)[0])[0], rtol=relative_tolerance, atol=absolute_tolerance) + + +Tin = TypeVarTuple('Tin') + + +def autodiff_test( + operator: Operator[Unpack[Tin], tuple[torch.Tensor, ...]], + *u: Unpack[Tin], +): + """Test if autodiff of an operator is working. + This test does not check that the gradient is correct but simply that it can be calculated using both torch.func.jvp + and torch.func.vjp. + + Parameters + ---------- + operator + operator + u + element(s) of the domain of the operator + + Raises + ------ + AssertionError + if autodiff fails + """ + # Forward-mode autodiff using jvp + with torch.autograd.detect_anomaly(): + v_range, _ = torch.func.jvp(operator.forward, u, u) + + # Backward-mode autodiff using vjp + with torch.autograd.detect_anomaly(): + (_, vjpfunc) = torch.func.vjp(operator.forward, *u) + vjpfunc(v_range) diff --git a/tests/operators/functionals/__init__.py b/tests/operators/functionals/__init__.py index e9d7aa091..878750b62 100644 --- a/tests/operators/functionals/__init__.py +++ b/tests/operators/functionals/__init__.py @@ -1,4 +1,5 @@ from mrpro.operators.functionals.L1NormViewAsReal import L1NormViewAsReal +from mrpro.operators.functionals.L1Norm import L1Norm from mrpro.operators.functionals.L2NormSquared import L2NormSquared -from mrpro.operators.functionals.MSEDataDiscrepancy import MSEDataDiscrepancy +from mrpro.operators.functionals.MSE import MSE from mrpro.operators.functionals.ZeroFunctional import ZeroFunctional diff --git a/tests/operators/functionals/conftest.py b/tests/operators/functionals/conftest.py index 9d2bfd8c9..ae8ffc20b 100644 --- a/tests/operators/functionals/conftest.py 
+++ b/tests/operators/functionals/conftest.py @@ -47,12 +47,12 @@ def result_dtype(self): def functional_test_cases(func: Callable[[FunctionalTestCase], None]) -> Callable[..., None]: """Decorator combining multiple parameterizations for test cases for all proximable functionals.""" - @pytest.mark.parametrize('shape', [[1, 2, 3]]) + @pytest.mark.parametrize('shape', [[1, 2, 3]], ids=['shape=[1,2,3]']) @pytest.mark.parametrize('dtype_name', ['float32', 'complex64']) @pytest.mark.parametrize('weight', ['scalar_weight', 'tensor_weight', 'complex_weight']) @pytest.mark.parametrize('target', ['no_target', 'random_target']) - @pytest.mark.parametrize('dim', [None]) - @pytest.mark.parametrize('divide_by_n', [True, False]) + @pytest.mark.parametrize('dim', [None], ids=['dim=None']) + @pytest.mark.parametrize('divide_by_n', [True, False], ids=['mean', 'sum']) @pytest.mark.parametrize('functional', PROXIMABLE_FUNCTIONALS) def wrapper( functional: type[ElementaryProximableFunctional], diff --git a/tests/operators/functionals/test_functionals.py b/tests/operators/functionals/test_functionals.py index e0b9efd65..39b16da12 100644 --- a/tests/operators/functionals/test_functionals.py +++ b/tests/operators/functionals/test_functionals.py @@ -1,10 +1,11 @@ from copy import deepcopy -from typing import Literal, TypedDict +from typing import Literal import pytest import torch from mrpro.operators.Functional import ElementaryFunctional, ElementaryProximableFunctional -from mrpro.operators.functionals import L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional +from mrpro.operators.functionals import MSE, L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional +from typing_extensions import TypedDict from tests import RandomGenerator from tests.operators.functionals.conftest import ( @@ -296,6 +297,19 @@ class NumericCase(TypedDict): [[[-2.983529, -1.943529, -1.049412], [-0.108235, 1.468235, 1.971765]]] ), }, + 'MSE': { + # Generated with ODL + 'functional': MSE, + 'x': 
torch.tensor([[[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]]), + 'weight': 2.0, + 'target': torch.tensor([[[0.340, 0.130, 0.230], [0.230, -1.120, -0.190]]]), + 'sigma': 0.5, + 'fx_expected': torch.tensor(17.6992), + 'prox_expected': torch.tensor([[[-1.6640, -1.1480, -0.5080], [0.0920, 0.1520, 1.1240]]]), + 'prox_convex_conj_expected': torch.tensor( + [[[-2.305455, -1.501818, -0.810909], [-0.083636, 1.134545, 1.523636]]] + ), + }, } diff --git a/tests/operators/functionals/test_mse_functional.py b/tests/operators/functionals/test_mse_functional.py deleted file mode 100644 index 3ff9b1ee8..000000000 --- a/tests/operators/functionals/test_mse_functional.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Tests for MSE-functional.""" - -import pytest -import torch -from mrpro.operators.functionals.MSEDataDiscrepancy import MSEDataDiscrepancy - - -@pytest.mark.parametrize( - ('data', 'x', 'expected_mse'), - [ - ((0.0, 0.0), (0.0, 0.0), (0.0)), # zero-tensors deliver 0-error - ((0.0 + 1j * 0, 0.0), (0.0 + 1j * 0, 0.0), (0.0)), # zero-tensors deliver 0-error; complex-valued - ((1.0, 0.0), (1.0, 0.0), (0.0)), # same tensors; both real-valued - ((1.0, 0.0), (1.0 + 1j * 0, 0.0), (0.0)), # same tensors; input complex-valued - ((1.0, 0.0), (1.0 + 1j * 1, 0.0), (0.5)), # different tensors; input complex-valued - ((1.0 + 1j * 0, 0.0), (1.0, 0.0), (0.0)), # same tensors; data complex-valued - ((1.0 + 1j * 1, 0.0), (1.0, 0.0), (0.5)), # different tensors; data complex-valued - ((1.0 + 1j * 0, 0.0), (1.0 + 1j * 0, 0.0), (0.0)), # same tensors; both complex-valued with imag part=0 - ((1.0 + 1j * 1, 0.0), (1.0 + 1j * 1, 0.0), (0.0)), # same tensors; both complex-valued with imag part>0 - ((0.0 + 1j * 1, 0.0), (0.0 + 1j * 1, 0.0), (0.0)), # same tensors; both complex-valued with real part=0 - ], -) -def test_mse_functional(data, x, expected_mse): - """Test if mse_data_discrepancy matches expected values. - - Expected values are supposed to be - 1/N*|| . 
- data||_2^2 - """ - - mse_op = MSEDataDiscrepancy(torch.tensor(data)) - (mse,) = mse_op(torch.tensor(x)) - torch.testing.assert_close(mse, torch.tensor(expected_mse)) diff --git a/tests/operators/models/conftest.py b/tests/operators/models/conftest.py index 570fa2f1f..4aab81ae0 100644 --- a/tests/operators/models/conftest.py +++ b/tests/operators/models/conftest.py @@ -8,7 +8,7 @@ SHAPE_VARIATIONS_SIGNAL_MODELS = pytest.mark.parametrize( ('parameter_shape', 'contrast_dim_shape', 'signal_shape'), [ - ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different inversion times + ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different contrast times ((1, 1, 10, 20, 30), (5, 1), (5, 1, 1, 10, 20, 30)), ((4, 1, 1, 10, 20, 30), (5, 1), (5, 4, 1, 1, 10, 20, 30)), # multiple maps along additional batch dimension ((4, 1, 1, 10, 20, 30), (5,), (5, 4, 1, 1, 10, 20, 30)), @@ -25,10 +25,30 @@ ((1,), (5,), (5, 1)), # single voxel ((4, 3, 1), (5, 4, 3), (5, 4, 3, 1)), ], + ids=[ + 'single_map_diff_contrast_times', + 'single_map_diff_contrast_times_2', + 'multiple_maps_additional_batch_dim', + 'multiple_maps_additional_batch_dim_2', + 'multiple_maps_additional_batch_dim_3', + 'multiple_maps_other_dim', + 'multiple_maps_other_dim_2', + 'multiple_maps_other_dim_3', + 'multiple_maps_other_and_batch_dim', + 'multiple_maps_other_and_batch_dim_2', + 'multiple_maps_other_and_batch_dim_3', + 'multiple_maps_other_and_batch_dim_4', + 'multiple_maps_other_and_batch_dim_5', + 'different_value_each_voxel', + 'single_voxel', + 'multiple_voxels', + ], ) -def create_parameter_tensor_tuples(parameter_shape=(10, 5, 100, 100, 100), number_of_tensors=2): +def create_parameter_tensor_tuples( + parameter_shape=(10, 5, 100, 100, 100), number_of_tensors=2 +) -> tuple[torch.Tensor, ...]: """Create tuples of tensors as input to operators.""" random_generator = RandomGenerator(seed=0) parameter_tensors = random_generator.float32_tensor(size=(number_of_tensors, 
*parameter_shape), low=1e-10) diff --git a/tests/operators/models/test_inversion_recovery.py b/tests/operators/models/test_inversion_recovery.py index 71bb4dfe6..b3d32a211 100644 --- a/tests/operators/models/test_inversion_recovery.py +++ b/tests/operators/models/test_inversion_recovery.py @@ -3,6 +3,7 @@ import pytest import torch from mrpro.operators.models import InversionRecovery +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -21,7 +22,7 @@ def test_inversion_recovery(ti, result): """ model = InversionRecovery(ti) m0, t1 = create_parameter_tensor_tuples() - (image,) = model.forward(m0, t1) + (image,) = model(m0, t1) # Assert closeness to -m0 for ti=0 if result == '-m0': @@ -37,5 +38,12 @@ def test_inversion_recovery_shape(parameter_shape, contrast_dim_shape, signal_sh (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = InversionRecovery(ti) m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, t1) + (signal,) = model_op(m0, t1) assert signal.shape == signal_shape + + +def test_autodiff_inversion_recovery(): + """Test autodiff works for inversion_recovery model.""" + model = InversionRecovery(ti=10) + m0, t1 = create_parameter_tensor_tuples(parameter_shape=(2, 5, 10, 10, 10), number_of_tensors=2) + autodiff_test(model, m0, t1) diff --git a/tests/operators/models/test_molli.py b/tests/operators/models/test_molli.py index 8ee5f9117..82b5f6c04 100644 --- a/tests/operators/models/test_molli.py +++ b/tests/operators/models/test_molli.py @@ -3,7 +3,7 @@ import pytest import torch from mrpro.operators.models import MOLLI -from tests import RandomGenerator +from tests import RandomGenerator, autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -27,7 +27,7 @@ def test_molli(ti, result): # Generate signal model 
and torch tensor for comparison model = MOLLI(ti) - (image,) = model.forward(a, c, t1) + (image,) = model(a, c, t1) # Assert closeness to a(1-c) for large ti if result == 'a(1-c)': @@ -43,5 +43,12 @@ def test_molli_shape(parameter_shape, contrast_dim_shape, signal_shape): (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = MOLLI(ti) a, c, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(a, c, t1) + (signal,) = model_op(a, c, t1) assert signal.shape == signal_shape + + +def test_autodiff_molli(): + """Test autodiff works for molli model.""" + model = MOLLI(ti=10) + a, b, t1 = create_parameter_tensor_tuples((2, 5, 10, 10, 10), number_of_tensors=3) + autodiff_test(model, a, b, t1) diff --git a/tests/operators/models/test_mono_exponential_decay.py b/tests/operators/models/test_mono_exponential_decay.py index c56e86e12..1aba27891 100644 --- a/tests/operators/models/test_mono_exponential_decay.py +++ b/tests/operators/models/test_mono_exponential_decay.py @@ -3,6 +3,7 @@ import pytest import torch from mrpro.operators.models import MonoExponentialDecay +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -21,7 +22,7 @@ def test_mono_exponential_decay(decay_time, result): """ model = MonoExponentialDecay(decay_time) m0, decay_constant = create_parameter_tensor_tuples() - (image,) = model.forward(m0, decay_constant) + (image,) = model(m0, decay_constant) zeros = torch.zeros_like(m0) @@ -39,5 +40,12 @@ def test_mono_exponential_decay_shape(parameter_shape, contrast_dim_shape, signa (decay_time,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = MonoExponentialDecay(decay_time) m0, decay_constant = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, decay_constant) + (signal,) = model_op(m0, decay_constant) 
assert signal.shape == signal_shape + + +def test_autodiff_exponential_decay(): + """Test autodiff works for mono-exponential decay model.""" + model = MonoExponentialDecay(decay_time=20) + m0, decay_constant = create_parameter_tensor_tuples(parameter_shape=(2, 5, 10, 10, 10), number_of_tensors=2) + autodiff_test(model, m0, decay_constant) diff --git a/tests/operators/models/test_saturation_recovery.py b/tests/operators/models/test_saturation_recovery.py index 1a821282b..0b2406ee6 100644 --- a/tests/operators/models/test_saturation_recovery.py +++ b/tests/operators/models/test_saturation_recovery.py @@ -3,6 +3,7 @@ import pytest import torch from mrpro.operators.models import SaturationRecovery +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -21,7 +22,7 @@ def test_saturation_recovery(ti, result): """ model = SaturationRecovery(ti) m0, t1 = create_parameter_tensor_tuples() - (image,) = model.forward(m0, t1) + (image,) = model(m0, t1) zeros = torch.zeros_like(m0) @@ -39,5 +40,12 @@ def test_saturation_recovery_shape(parameter_shape, contrast_dim_shape, signal_s (ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = SaturationRecovery(ti) m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) - (signal,) = model_op.forward(m0, t1) + (signal,) = model_op(m0, t1) assert signal.shape == signal_shape + + +def test_autodiff_aturation_recovery(): + """Test autodiff works for aturation recovery model.""" + model = SaturationRecovery(ti=10) + m0, t1 = create_parameter_tensor_tuples((2, 5, 10, 10, 10), number_of_tensors=2) + autodiff_test(model, m0, t1) diff --git a/tests/operators/models/test_transient_steady_state_with_preparation.py b/tests/operators/models/test_transient_steady_state_with_preparation.py index 87d21a336..5d43f8eb8 100644 --- a/tests/operators/models/test_transient_steady_state_with_preparation.py +++ 
b/tests/operators/models/test_transient_steady_state_with_preparation.py @@ -4,6 +4,7 @@ import torch from einops import repeat from mrpro.operators.models import TransientSteadyStateWithPreparation +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -23,7 +24,7 @@ def test_transient_steady_state(sampling_time, m0_scaling_preparation, result): repetition_time = 5 model = TransientSteadyStateWithPreparation(sampling_time, repetition_time, m0_scaling_preparation) m0, t1, flip_angle = create_parameter_tensor_tuples(number_of_tensors=3) - (signal,) = model.forward(m0, t1, flip_angle) + (signal,) = model(m0, t1, flip_angle) # Assert closeness to m0 if result == 'm0': @@ -56,7 +57,7 @@ def test_transient_steady_state_inversion_recovery(): analytical_signal = m0 * (1 - 2 * torch.exp(-(sampling_time / t1))) model = TransientSteadyStateWithPreparation(sampling_time, repetition_time=100, m0_scaling_preparation=-1) - (signal,) = model.forward(m0, t1, flip_angle) + (signal,) = model(m0, t1, flip_angle) torch.testing.assert_close(signal, analytical_signal) @@ -66,9 +67,9 @@ def test_transient_steady_state_shape(parameter_shape, contrast_dim_shape, signa """Test correct signal shapes.""" (sampling_time,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) if len(parameter_shape) == 1: - repetition_time = 5 - m0_scaling_preparation = 1 - delay_after_preparation = 0.01 + repetition_time: float | torch.Tensor = 5 + m0_scaling_preparation: float | torch.Tensor = 1 + delay_after_preparation: float | torch.Tensor = 0.01 else: repetition_time, m0_scaling_preparation, delay_after_preparation = create_parameter_tensor_tuples( contrast_dim_shape[1:], number_of_tensors=3 @@ -77,5 +78,19 @@ def test_transient_steady_state_shape(parameter_shape, contrast_dim_shape, signa sampling_time, repetition_time, m0_scaling_preparation, delay_after_preparation ) m0, t1, flip_angle = 
create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(m0, t1, flip_angle) + (signal,) = model_op(m0, t1, flip_angle) assert signal.shape == signal_shape + + +def test_autodiff_transient_steady_state(): + """Test autodiff works for transient steady state model.""" + contrast_dim_shape = (6,) + (sampling_time,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) + repetition_time, m0_scaling_preparation, delay_after_preparation = create_parameter_tensor_tuples( + contrast_dim_shape[1:], number_of_tensors=3 + ) + model = TransientSteadyStateWithPreparation( + sampling_time, repetition_time, m0_scaling_preparation, delay_after_preparation + ) + m0, t1, flip_angle = create_parameter_tensor_tuples(parameter_shape=(2, 5, 10, 10, 10), number_of_tensors=3) + autodiff_test(model, m0, t1, flip_angle) diff --git a/tests/operators/models/test_wasabi.py b/tests/operators/models/test_wasabi.py index 1aa36384e..3e58e0a05 100644 --- a/tests/operators/models/test_wasabi.py +++ b/tests/operators/models/test_wasabi.py @@ -2,10 +2,13 @@ import torch from mrpro.operators.models import WASABI +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples -def create_data(offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, c=1.0, d=2.0): +def create_data( + offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, c=1.0, d=2.0 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: offsets = torch.linspace(-offset_max, offset_max, n_offsets) return offsets, torch.Tensor([b0_shift]), torch.Tensor([rb1]), torch.Tensor([c]), torch.Tensor([d]) @@ -14,14 +17,14 @@ def test_WASABI_shift(): """Test symmetry property of shifted and unshifted WASABI spectra.""" offsets_unshifted, b0_shift, rb1, c, d = create_data() wasabi_model = WASABI(offsets=offsets_unshifted) - (signal,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal,) = 
wasabi_model(b0_shift, rb1, c, d) offsets_shifted, b0_shift, rb1, c, d = create_data(b0_shift=100) wasabi_model = WASABI(offsets=offsets_shifted) - (signal_shifted,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal_shifted,) = wasabi_model(b0_shift, rb1, c, d) - lower_index = (offsets_shifted == -300).nonzero()[0][0].item() - upper_index = (offsets_shifted == 500).nonzero()[0][0].item() + lower_index = int((offsets_shifted == -300).nonzero()[0][0]) + upper_index = int((offsets_shifted == 500).nonzero()[0][0]) assert signal[0] == signal[-1], 'Result should be symmetric around center' assert signal_shifted[lower_index] == signal_shifted[upper_index], 'Result should be symmetric around shift' @@ -30,7 +33,7 @@ def test_WASABI_shift(): def test_WASABI_extreme_offset(): offset, b0_shift, rb1, c, d = create_data(offset_max=30000, n_offsets=1) wasabi_model = WASABI(offsets=offset) - (signal,) = wasabi_model.forward(b0_shift, rb1, c, d) + (signal,) = wasabi_model(b0_shift, rb1, c, d) assert torch.isclose(signal, torch.tensor([1.0])), 'For an extreme offset, the signal should be unattenuated' @@ -41,5 +44,12 @@ def test_WASABI_shape(parameter_shape, contrast_dim_shape, signal_shape): (offsets,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) model_op = WASABI(offsets) b0_shift, rb1, c, d = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=4) - (signal,) = model_op.forward(b0_shift, rb1, c, d) + (signal,) = model_op(b0_shift, rb1, c, d) assert signal.shape == signal_shape + + +def test_autodiff_WASABI(): + """Test autodiff works for WASABI model.""" + offset, b0_shift, rb1, c, d = create_data(offset_max=300, n_offsets=2) + wasabi_model = WASABI(offsets=offset) + autodiff_test(wasabi_model, b0_shift, rb1, c, d) diff --git a/tests/operators/models/test_wasabiti.py b/tests/operators/models/test_wasabiti.py index bd451a03b..637f9ff9e 100644 --- a/tests/operators/models/test_wasabiti.py +++ b/tests/operators/models/test_wasabiti.py 
@@ -3,10 +3,13 @@ import pytest import torch from mrpro.operators.models import WASABITI +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples -def create_data(offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, t1=1.0): +def create_data( + offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, t1=1.0 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: offsets = torch.linspace(-offset_max, offset_max, n_offsets) return offsets, torch.Tensor([b0_shift]), torch.Tensor([rb1]), torch.Tensor([t1]) @@ -15,7 +18,7 @@ def test_WASABITI_symmetry(): """Test symmetry property of complete WASABITI spectra.""" offsets, b0_shift, rb1, t1 = create_data() wasabiti_model = WASABITI(offsets=offsets, trec=torch.ones_like(offsets)) - (signal,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal,) = wasabiti_model(b0_shift, rb1, t1) # check that all values are symmetric around the center assert torch.allclose(signal, signal.flipud(), rtol=1e-15), 'Result should be symmetric around center' @@ -26,10 +29,10 @@ def test_WASABITI_symmetry_after_shift(): offsets_shifted, b0_shift, rb1, t1 = create_data(b0_shift=100) trec = torch.ones_like(offsets_shifted) wasabiti_model = WASABITI(offsets=offsets_shifted, trec=trec) - (signal_shifted,) = wasabiti_model.forward(b0_shift, rb1, t1) + (signal_shifted,) = wasabiti_model(b0_shift, rb1, t1) - lower_index = (offsets_shifted == -300).nonzero()[0][0].item() - upper_index = (offsets_shifted == 500).nonzero()[0][0].item() + lower_index = int((offsets_shifted == -300).nonzero()[0][0]) + upper_index = int((offsets_shifted == 500).nonzero()[0][0]) assert signal_shifted[lower_index] == signal_shifted[upper_index], 'Result should be symmetric around shift' @@ -42,7 +45,7 @@ def test_WASABITI_asymmetry_for_non_unique_trec(): trec[: len(offsets_unshifted) // 2] = 2.0 wasabiti_model = WASABITI(offsets=offsets_unshifted, trec=trec) - (signal,) = 
wasabiti_model.forward(b0_shift, rb1, t1) + (signal,) = wasabiti_model(b0_shift, rb1, t1) assert not torch.allclose(signal, signal.flipud(), rtol=1e-8), 'Result should not be symmetric around center' @@ -53,7 +56,7 @@ def test_WASABITI_relaxation_term(t1): offset, b0_shift, rb1, t1 = create_data(offset_max=50000, n_offsets=1, t1=t1) trec = torch.ones_like(offset) * t1 wasabiti_model = WASABITI(offsets=offset, trec=trec) - sig = wasabiti_model.forward(b0_shift, rb1, t1) + sig = wasabiti_model(b0_shift, rb1, t1) assert torch.isclose(sig[0], torch.FloatTensor([1 - torch.exp(torch.FloatTensor([-1]))]), rtol=1e-8) @@ -72,5 +75,13 @@ def test_WASABITI_shape(parameter_shape, contrast_dim_shape, signal_shape): ti, trec = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=2) model_op = WASABITI(ti, trec) b0_shift, rb1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op.forward(b0_shift, rb1, t1) + (signal,) = model_op(b0_shift, rb1, t1) assert signal.shape == signal_shape + + +def test_autodiff_WASABITI(): + """Test autodiff works for WASABITI model.""" + offset, b0_shift, rb1, t1 = create_data(offset_max=300, n_offsets=2) + trec = torch.ones_like(offset) * t1 + wasabiti_model = WASABITI(offsets=offset, trec=trec) + autodiff_test(wasabiti_model, b0_shift, rb1, t1) diff --git a/tests/operators/test_autograd_linop.py b/tests/operators/test_autograd_linop.py index f2df2d1af..da833de98 100644 --- a/tests/operators/test_autograd_linop.py +++ b/tests/operators/test_autograd_linop.py @@ -5,8 +5,7 @@ from mrpro.operators import LinearOperator from torch.autograd.gradcheck import GradcheckError -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test class NonDifferentiableOperator(LinearOperator, adjoint_as_backward=False): diff --git a/tests/operators/test_cartesian_sampling_op.py 
b/tests/operators/test_cartesian_sampling_op.py index 6a1120e79..0fd320212 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -2,12 +2,13 @@ import pytest import torch +from einops import rearrange from mrpro.data import KTrajectory, SpatialDimension from mrpro.operators import CartesianSamplingOp +from typing_extensions import Unpack -from tests import RandomGenerator +from tests import RandomGenerator, dotproduct_adjointness_test from tests.conftest import create_traj -from tests.helper import dotproduct_adjointness_test def test_cart_sampling_op_data_match(): @@ -16,10 +17,10 @@ def test_cart_sampling_op_data_match(): nkx = (1, 1, 1, 60) nky = (1, 1, 40, 1) nkz = (1, 20, 1, 1) - sx = 'uf' - sy = 'uf' - sz = 'uf' - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + type_kx = 'uniform' + type_ky = 'uniform' + type_kz = 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Create matching data random_generator = RandomGenerator(seed=0) @@ -50,34 +51,15 @@ def test_cart_sampling_op_data_match(): torch.testing.assert_close(kdata[:, :, ::2, ::4, ::3], k_sub[:, :, ::2, ::4, ::3]) -@pytest.mark.parametrize( - 'sampling', - [ - 'random', - 'partial_echo', - 'partial_fourier', - 'regular_undersampling', - 'random_undersampling', - 'different_random_undersampling', - ], -) -def test_cart_sampling_op_fwd_adj(sampling): - """Test adjoint property of Cartesian sampling operator.""" - - # Create 3D uniform trajectory - k_shape = (2, 5, 20, 40, 60) - nkx = (2, 1, 1, 60) - nky = (2, 1, 40, 1) - nkz = (2, 20, 1, 1) - sx = 'uf' - sy = 'uf' - sz = 'uf' - trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() - +def subsample_traj( + trajectory: KTrajectory, sampling: str, k_shape: tuple[int, int, int, Unpack[tuple[int, ...]]] +) -> KTrajectory: + """Subsample trajectory based on sampling type.""" + trajectory_tensor = trajectory.as_tensor() # 
Subsample data and trajectory match sampling: case 'random': - random_idx = torch.randperm(k_shape[-2]) + random_idx = RandomGenerator(13).randperm(k_shape[-2]) trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx, :]) case 'partial_echo': trajectory = KTrajectory.from_tensor(trajectory_tensor[..., : k_shape[-1] // 2]) @@ -86,16 +68,55 @@ def test_cart_sampling_op_fwd_adj(sampling): case 'regular_undersampling': trajectory = KTrajectory.from_tensor(trajectory_tensor[..., ::3, ::5, :]) case 'random_undersampling': - random_idx = torch.randperm(k_shape[-2]) + random_idx = RandomGenerator(13).randperm(k_shape[-2]) trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: k_shape[-2] // 2], :]) case 'different_random_undersampling': traj_list = [ - traj_one_other[..., torch.randperm(k_shape[-2])[: k_shape[-2] // 2], :] + traj_one_other[..., RandomGenerator(13).randperm(k_shape[-2])[: k_shape[-2] // 2], :] for traj_one_other in trajectory_tensor.unbind(1) ] trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1)) + case 'cartesian_and_non_cartesian': + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0_undersampling': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 
1 (k1 k0)') + random_idx = RandomGenerator(13).randperm(trajectory_tensor.shape[-1]) + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]]) case _: raise NotImplementedError(f'Test {sampling} not implemented.') + return trajectory + + +@pytest.mark.parametrize( + 'sampling', + [ + 'random', + 'partial_echo', + 'partial_fourier', + 'regular_undersampling', + 'random_undersampling', + 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', + ], +) +def test_cart_sampling_op_fwd_adj(sampling): + """Test adjoint property of Cartesian sampling operator.""" + + # Create 3D uniform trajectory + k_shape = (2, 5, 20, 40, 60) + nkx = (2, 1, 1, 60) + nky = (2, 1, 40, 1) + nkz = (2, 20, 1, 1) + type_kx = 'uniform' + type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) + trajectory = subsample_traj(trajectory, sampling, k_shape) encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) @@ -105,3 +126,64 @@ def test_cart_sampling_op_fwd_adj(sampling): u = random_generator.complex64_tensor(size=k_shape) v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:]) dotproduct_adjointness_test(sampling_op, u, v) + + +@pytest.mark.parametrize( + 'sampling', + [ + 'random', + 'partial_echo', + 'partial_fourier', + 'regular_undersampling', + 'random_undersampling', + 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', + ], +) +def test_cart_sampling_op_gram(sampling): + """Test adjoint gram of Cartesian sampling operator.""" + + # Create 3D uniform trajectory + k_shape = (2, 5, 
20, 40, 60) + nkx = (2, 1, 1, 60) + nky = (2, 1, 40, 1) + nkz = (2, 20, 1, 1) + type_kx = 'uniform' + type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) + trajectory = subsample_traj(trajectory, sampling, k_shape) + + encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) + u = RandomGenerator(seed=0).complex64_tensor(size=k_shape) + (expected,) = (sampling_op.H @ sampling_op)(u) + (actual,) = sampling_op.gram(u) + torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)]) +@pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)]) +def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max): + """Test trajectory points outside of encoding_matrix.""" + encoding_matrix = SpatialDimension(40, 1, 20) + + # Create kx and kz sampling which are asymmetric and larger than the encoding matrix on one side + # The indices are inverted to ensure CartesianSamplingOp acts on them + kx = rearrange(torch.linspace(k0_max, k0_min, 20), 'kx->1 1 1 kx') + ky = torch.ones(1, 1, 1, 1) + kz = rearrange(torch.linspace(k2_max, k2_min, 40), 'kz-> kz 1 1') + kz = torch.stack([kz, -kz], dim=0) # different kz values for two other elements + trajectory = KTrajectory(kz=kz, ky=ky, kx=kx) + + with pytest.warns(UserWarning, match='K-space points lie outside of the encoding_matrix'): + sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) + + random_generator = RandomGenerator(seed=0) + u = random_generator.complex64_tensor(size=(3, 2, 5, kz.shape[-3], ky.shape[-2], kx.shape[-1])) + v = random_generator.complex64_tensor(size=(3, 2, 5, *encoding_matrix.zyx)) + + assert 
sampling_op.adjoint(u)[0].shape[-3:] == encoding_matrix.zyx + assert sampling_op(v)[0].shape[-3:] == (kz.shape[-3], ky.shape[-2], kx.shape[-1]) diff --git a/tests/operators/test_constraints_op.py b/tests/operators/test_constraints_op.py index 5d0ba55b0..b11f8f6d2 100644 --- a/tests/operators/test_constraints_op.py +++ b/tests/operators/test_constraints_op.py @@ -4,7 +4,7 @@ import torch from mrpro.operators import ConstraintsOp -from tests import RandomGenerator +from tests import RandomGenerator, autodiff_test @pytest.mark.parametrize( @@ -141,3 +141,15 @@ def test_constraints_operator_multiple_inputs(bounds): def test_constraints_operator_illegal_bounds(bounds): with pytest.raises(ValueError, match='invalid'): ConstraintsOp(bounds) + + +def test_autodiff_constraints_operator(): + """Test autodiff works for constraints operator.""" + # random tensors with arbitrary values + random_generator = RandomGenerator(seed=0) + x1 = random_generator.float32_tensor(size=(36, 72), low=-1, high=1) + x2 = random_generator.float32_tensor(size=(36, 72), low=-1, high=1) + x3 = random_generator.float32_tensor(size=(36, 72), low=-1, high=1) + + constraints_op = ConstraintsOp(bounds=((None, None), (1.0, None), (None, 1.0))) + autodiff_test(constraints_op, x1, x2, x3) diff --git a/tests/operators/test_density_compensation_op.py b/tests/operators/test_density_compensation_op.py index 616d0e8f9..96b59547e 100644 --- a/tests/operators/test_density_compensation_op.py +++ b/tests/operators/test_density_compensation_op.py @@ -4,8 +4,7 @@ from mrpro.data import DcfData from mrpro.operators import DensityCompensationOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_density_compensation_op_adjointness(): diff --git a/tests/operators/test_einsum_op.py b/tests/operators/test_einsum_op.py index 1db7c0ac3..8e098ab32 100644 --- a/tests/operators/test_einsum_op.py +++ 
b/tests/operators/test_einsum_op.py @@ -4,8 +4,7 @@ import torch from mrpro.operators.EinsumOp import EinsumOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize('dtype', ['float32', 'complex128']) diff --git a/tests/operators/test_fast_fourier_op.py b/tests/operators/test_fast_fourier_op.py index 7ac1a8211..b7fd94576 100644 --- a/tests/operators/test_fast_fourier_op.py +++ b/tests/operators/test_fast_fourier_op.py @@ -6,8 +6,7 @@ from mrpro.data import SpatialDimension from mrpro.operators import FastFourierOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize(('npoints', 'a'), [(100, 20), (300, 20)]) @@ -28,7 +27,7 @@ def test_fast_fourier_op_forward(npoints, a): # Transform image to k-space ff_op = FastFourierOp(dim=(0,)) - (igauss_fwd,) = ff_op.forward(igauss) + (igauss_fwd,) = ff_op(igauss) # Scaling to "undo" fft scaling igauss_fwd *= np.sqrt(npoints) / 2 diff --git a/tests/operators/test_finite_difference_op.py b/tests/operators/test_finite_difference_op.py index f79da441b..ea21ae919 100644 --- a/tests/operators/test_finite_difference_op.py +++ b/tests/operators/test_finite_difference_op.py @@ -5,8 +5,7 @@ from einops import repeat from mrpro.operators import FiniteDifferenceOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize('mode', ['central', 'forward', 'backward']) diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 89a4bbc11..c7c58c266 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -1,34 +1,41 @@ """Tests for Fourier operator.""" import pytest -from mrpro.data import SpatialDimension +import 
torch +from mrpro.data import KData, KTrajectory, SpatialDimension +from mrpro.data.traj_calculators import KTrajectoryCartesian from mrpro.operators import FourierOp -from tests import RandomGenerator +from tests import RandomGenerator, dotproduct_adjointness_test from tests.conftest import COMMON_MR_TRAJECTORIES, create_traj -from tests.helper import dotproduct_adjointness_test -def create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): random_generator = RandomGenerator(seed=0) # generate random image img = random_generator.complex64_tensor(size=im_shape) # create random trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) return img, trajectory @COMMON_MR_TRAJECTORIES -def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_fourier_op_fwd_adj_property( + im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2 +): """Test adjoint property of Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) # apply forward operator @@ -41,30 +48,72 @@ def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, dotproduct_adjointness_test(fourier_op, 
u, v) +@COMMON_MR_TRAJECTORIES +def test_fourier_op_gram(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): + """Test gram of Fourier operator.""" + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) + + recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) + fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + (expected,) = (fourier_op.H @ fourier_op)(img) + (actual,) = fourier_op.gram(img) + + torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3) + + @pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz'), # parameter names [ - # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), # Cartesian ky dimension defined along k2 rather than k1 - (5, 1, 18, 64), - 'nuf', - 'uf', - 'nuf', + ( # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 96, 1, 1), # nky - Cartesian ky dimension defined along k2 rather than k1 + (5, 1, 18, 64), # nkz + 'non-uniform', # type_kx + 'uniform', # type_ky + 'non-uniform', # type_kz ), ], + ids=['cartesian_fft_dims_not_aligned_with_k2_k1_k0_dims'], ) -def test_fourier_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Test trajectory not supported by Fourier operator.""" # generate random images and k-space 
trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + +def test_fourier_op_cartesian_sorting(ismrmrd_cart): + """Verify correct sorting of Cartesian k-space data before FFT.""" + kdata = KData.from_file(ismrmrd_cart.filename, KTrajectoryCartesian()) + ff_op = FourierOp.from_kdata(kdata) + (img,) = ff_op.adjoint(kdata.data) + + # shuffle the kspace points along k0 + permutation_index = RandomGenerator(13).randperm(kdata.data.shape[-1]) + kdata_unsorted = KData( + header=kdata.header, + data=kdata.data[..., permutation_index], + traj=KTrajectory.from_tensor(kdata.traj.as_tensor()[..., permutation_index]), + ) + ff_op_unsorted = FourierOp.from_kdata(kdata_unsorted) + (img_unsorted,) = ff_op_unsorted.adjoint(kdata_unsorted.data) + + torch.testing.assert_close(img, img_unsorted) diff --git a/tests/operators/test_grid_sampling_op.py b/tests/operators/test_grid_sampling_op.py index ed020956a..c8d2ccaf0 100644 --- a/tests/operators/test_grid_sampling_op.py +++ b/tests/operators/test_grid_sampling_op.py @@ -8,8 +8,7 @@ from mrpro.operators import GridSamplingOp from torch.autograd.gradcheck import gradcheck -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize('dtype', ['float32', 
'float64', 'complex64']) diff --git a/tests/operators/test_identity_op.py b/tests/operators/test_identity_op.py index fa0ea600a..1855dc67b 100644 --- a/tests/operators/test_identity_op.py +++ b/tests/operators/test_identity_op.py @@ -1,10 +1,9 @@ """Tests for Identity Linear Operator and MultiIdentity Operator.""" -from typing import assert_type - import torch from mrpro.operators import IdentityOp, MagnitudeOp, MultiIdentityOp from mrpro.operators.LinearOperator import LinearOperator +from typing_extensions import assert_type from tests import RandomGenerator diff --git a/tests/operators/test_linearoperatormatrix.py b/tests/operators/test_linearoperatormatrix.py new file mode 100644 index 000000000..2a669d799 --- /dev/null +++ b/tests/operators/test_linearoperatormatrix.py @@ -0,0 +1,321 @@ +from typing import Any + +import pytest +import torch +from mrpro.operators import EinsumOp, LinearOperator, MagnitudeOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix + +from tests import RandomGenerator, dotproduct_adjointness_test + + +def random_linearop(size, rng): + """Create a random LinearOperator.""" + return EinsumOp(rng.complex64_tensor(size), '... i j, ... j -> ... 
i') + + +def random_linearoperatormatrix(size, inner_size, rng): + """Create a random LinearOperatorMatrix.""" + operators = [[random_linearop(inner_size, rng) for i in range(size[1])] for j in range(size[0])] + return LinearOperatorMatrix(operators) + + +def test_linearoperatormatrix_shape(): + """Test creation and shape of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert matrix.shape == (5, 3) + + +def test_linearoperatormatrix_add_matrix(): + """Test addition of two LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 + matrix2)(*vector) + expected = tuple(a + b for a, b in zip(matrix1(*vector), matrix2(*vector), strict=False)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_add_tensor_nonsquare(): + """Test failure of addition of tensor to non-square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + with pytest.raises(NotImplementedError, match='square'): + (matrix1 + other) + + +def test_linearoperatormatrix_add_tensor_square(): + """Add tensor to square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 3), (2, 2), rng) + other = rng.complex64_tensor(2) + vector = rng.complex64_tensor((3, 2)) + result = (matrix1 + other)(*vector) + expected = tuple((mv + other * v for mv, v in zip(matrix1(*vector), vector, strict=True))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_rmul(): + """Test post multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + vector = rng.complex64_tensor((3, 10)) + result = (other * matrix1)(*vector) + 
expected = tuple(other * el for el in matrix1(*vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_mul(): + """Test pre multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(10) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 * other)(*vector) + expected = matrix1(*(other * el for el in vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition(): + """Test composition of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 @ matrix2)(*vector) + expected = matrix1(*(matrix2(*vector))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition_mismatch(): + """Test composition with mismatching shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((4, 3), (3, 10), rng) + vector = rng.complex64_tensor((4, 10)) + with pytest.raises(ValueError, match='shapes do not match'): + (matrix1 @ matrix2)(*vector) + + +def test_linearoperatormatrix_adjoint(): + """Test adjointness of Adjoint.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + dotproduct_adjointness_test(Wrapper(), rng.complex64_tensor((3, 10)), rng.complex64_tensor((5, 3))) + + +def test_linearoperatormatrix_repr(): + """Test repr of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert 
'LinearOperatorMatrix(shape=(5, 3)' in repr(matrix) + + +def test_linearoperatormatrix_getitem(): + """Test slicing of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + def check(actual, expected): + assert tuple(tuple(row) for row in actual) == tuple(tuple(row) for row in expected) + + sliced = matrix[1:3, 2] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[2:3] for row in matrix._operators[1:3]]) + + sliced = matrix[0] + assert sliced.shape == (1, 6) + check(sliced._operators, matrix._operators[:1]) + + sliced = matrix[..., 0] + assert sliced.shape == (12, 1) + check(sliced._operators, [row[:1] for row in matrix._operators]) + + sliced = matrix[1:6:2, (3, 4)] + assert sliced.shape == (3, 2) + check(sliced._operators, [[matrix._operators[i][j] for j in (3, 4)] for i in range(1, 6, 2)]) + + sliced = matrix[-2:-4:-1, -1] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[-1:] for row in matrix._operators[-2:-4:-1]]) + + sliced = matrix[5, 5] + assert isinstance(sliced, LinearOperator) + assert sliced == matrix._operators[5][5] + + +def test_linearoperatormatrix_getitem_error(): + """Test error when slicing with wrong indices.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + with pytest.raises(IndexError, match='Too many indices'): + matrix[1, 1, 1] + + with pytest.raises(IndexError, match='out of range'): + matrix[20] + with pytest.raises(IndexError, match='out of range'): + matrix[-20] + with pytest.raises(IndexError, match='out of range'): + matrix[1:100] + with pytest.raises(IndexError, match='out of range'): + matrix[(100, 1)] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., -20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 1:100] + with pytest.raises(IndexError, match='index type'): + matrix[..., 1.0] + 
+ +def test_linearoperatormatrix_norm_rows(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((3, 1), (3, 10), rng) + vector = rng.complex64_tensor((1, 10)) + result = matrix.operator_norm(*vector) + expected = sum(row[0].operator_norm(vector[0], dim=None) ** 2 for row in matrix._operators) ** 0.5 + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_norm_cols(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((1, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = matrix.operator_norm(*vector) + expected = max(op.operator_norm(v, dim=None) for op, v in zip(matrix._operators[0], vector, strict=False)) + torch.testing.assert_close(result, expected) + + +@pytest.mark.parametrize('seed', [0, 1, 2, 3]) +def test_linearoperatormatrix_norm(seed): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(seed) + matrix = random_linearoperatormatrix((4, 2), (3, 10), rng) + vector = rng.complex64_tensor((2, 10)) + result = matrix.operator_norm(*vector) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + real = Wrapper().operator_norm(vector, dim=None) + + assert result >= real + + +def test_linearoperatormatrix_shorthand_vertical(): + """Test shorthand for vertical stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 10), rng) + op2 = random_linearop((4, 10), rng) + x1 = rng.complex64_tensor((10,)) + + matrix1 = op1 & op2 + assert matrix1.shape == (2, 1) + + actual = matrix1(x1) + expected = (*op1(x1), *op2(x1)) + torch.testing.assert_close(actual, expected) + + matrix2 = op2 & (matrix1 & op1) + assert matrix2.shape == (4, 1) + + matrix3 = matrix2 & matrix2 + assert matrix3.shape == (8, 1) + + actual = matrix3(x1) 
+ expected = 2 * (*op2(x1), *matrix1(x1), *op1(x1)) + torch.testing.assert_close(actual, expected) + + +def test_linearoperatormatrix_shorthand_horizontal(): + """Test shorthand for horizontal stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 4), rng) + op2 = random_linearop((3, 2), rng) + x1 = rng.complex64_tensor((4,)) + x2 = rng.complex64_tensor((2,)) + x3 = rng.complex64_tensor((4,)) + x4 = rng.complex64_tensor((2,)) + + matrix1 = op1 | op2 + assert matrix1.shape == (1, 2) + + actual1 = matrix1(x1, x2) + expected1 = (op1(x1)[0] + op2(x2)[0],) + torch.testing.assert_close(actual1, expected1) + + matrix2 = op2 | (matrix1 | op1) + assert matrix2.shape == (1, 4) + + matrix3 = matrix2 | matrix2 + assert matrix3.shape == (1, 8) + + expected3 = (2 * (op2(x2)[0] + (matrix1(x3, x4)[0] + op1(x1)[0])),) + actual3 = matrix3(x2, x3, x4, x1, x2, x3, x4, x1) + torch.testing.assert_close(actual3, expected3) + + +def test_linearoperatormatrix_stacking_error(): + """Test error when stacking matrix operators with different shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 4), (3, 10), rng) + matrix2 = random_linearoperatormatrix((3, 2), (3, 10), rng) + matrix3 = random_linearoperatormatrix((2, 4), (3, 10), rng) + op = random_linearop((3, 10), rng) + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & matrix2 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | matrix3 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | op + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & op + + +def test_linearoperatormatrix_error_nonlinearop(): + """Test error if trying to create a LinearOperatorMatrix with non linear operator.""" + op: Any = [[MagnitudeOp()]] # Any is used to hide this error from mypy + with pytest.raises(ValueError, match='LinearOperator'): + LinearOperatorMatrix(op) + + +def test_linearoperatormatrix_error_inconsistent_shapes(): + """Test error if trying to 
create a LinearOperatorMatrix with inconsistent row lengths.""" + rng = RandomGenerator(0) + op = random_linearop((3, 4), rng) + with pytest.raises(ValueError, match='same length'): + LinearOperatorMatrix([[op, op], [op]]) + + +def test_linearoperatormatrix_from_diagonal(): + """Test creation of LinearOperatorMatrix from diagonal.""" + rng = RandomGenerator(0) + ops = [random_linearop((2, 4), rng) for _ in range(3)] + matrix = LinearOperatorMatrix.from_diagonal(*ops) + xs = rng.complex64_tensor((3, 4)) + actual = matrix(*xs) + expected = tuple(op(x)[0] for op, x in zip(ops, xs, strict=False)) + torch.testing.assert_close(actual, expected) diff --git a/tests/operators/test_magnitude_op.py b/tests/operators/test_magnitude_op.py index 88fc28209..d4cab4974 100644 --- a/tests/operators/test_magnitude_op.py +++ b/tests/operators/test_magnitude_op.py @@ -3,15 +3,23 @@ import torch from mrpro.operators import MagnitudeOp -from tests import RandomGenerator +from tests import RandomGenerator, autodiff_test def test_magnitude_operator_forward(): """Test that MagnitudeOp returns abs of tensors.""" - rng = RandomGenerator(2) - a = rng.complex64_tensor((2, 3)) - b = rng.complex64_tensor((3, 10)) + random_generator = RandomGenerator(seed=2) + a = random_generator.complex64_tensor((2, 3)) + b = random_generator.complex64_tensor((3, 10)) magnitude_op = MagnitudeOp() magnitude_a, magnitude_b = magnitude_op(a, b) assert torch.allclose(magnitude_a, torch.abs(a)) assert torch.allclose(magnitude_b, torch.abs(b)) + + +def test_autodiff_magnitude_operator(): + """Test autodiff works for magnitude operator.""" + random_generator = RandomGenerator(seed=2) + a = random_generator.complex64_tensor((5, 9, 8)) + b = random_generator.complex64_tensor((10, 11, 12)) + autodiff_test(MagnitudeOp(), a, b) diff --git a/tests/operators/test_operator_norm.py b/tests/operators/test_operator_norm.py index d5ae39873..c6e13dcf9 100644 --- a/tests/operators/test_operator_norm.py +++ 
b/tests/operators/test_operator_norm.py @@ -12,9 +12,8 @@ def test_power_iteration_uses_stopping_criterion(): """Test if the power iteration stops if the absolute and relative tolerance are chosen high.""" - # callback function that should not be called because the power iteration - # should stop if the tolerances are set high - def callback(): + def callback(_): + """Callback function that should not be called, because the power iteration should stop.""" pytest.fail('The power iteration did not stop despite high atol and rtol!') random_generator = RandomGenerator(seed=0) diff --git a/tests/operators/test_operators.py b/tests/operators/test_operators.py index 82fcd113d..378060b74 100644 --- a/tests/operators/test_operators.py +++ b/tests/operators/test_operators.py @@ -1,13 +1,13 @@ """Tests for the operators module.""" -from typing import Any, assert_type, cast +from typing import cast import pytest import torch from mrpro.operators import LinearOperator, Operator +from typing_extensions import Any, assert_type -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test class DummyOperator(Operator[torch.Tensor, tuple[torch.Tensor,]]): diff --git a/tests/operators/test_pca_compression_op.py b/tests/operators/test_pca_compression_op.py new file mode 100644 index 000000000..e73bf3951 --- /dev/null +++ b/tests/operators/test_pca_compression_op.py @@ -0,0 +1,49 @@ +"""Tests for PCA Compression Operator.""" + +import pytest +from mrpro.operators import PCACompressionOp + +from tests import RandomGenerator, dotproduct_adjointness_test + + +@pytest.mark.parametrize( + ('init_data_shape', 'input_shape', 'n_components'), + [ + ((40, 10), (100, 10), 6), + ((40, 10), (3, 4, 5, 100, 10), 3), + ((3, 4, 40, 10), (3, 4, 100, 10), 6), + ((3, 4, 40, 10), (7, 3, 4, 100, 10), 3), + ], +) +def test_pca_compression_op_adjoint(init_data_shape, input_shape, n_components): + """Test adjointness 
of PCA Compression Op.""" + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + u = generator.complex64_tensor(input_shape) + output_shape = (*input_shape[:-1], n_components) + v = generator.complex64_tensor(output_shape) + + # Create operator and apply + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=n_components) + dotproduct_adjointness_test(pca_comp_op, u, v) + + +def test_pca_compression_op_wrong_shapes(): + """Test if Operator raises error if shape mismatch.""" + init_data_shape = (10, 6) + input_shape = (100, 3) + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + input_data = generator.complex64_tensor(input_shape) + + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=2) + + with pytest.raises(RuntimeError, match='Matrix'): + pca_comp_op(input_data) + + with pytest.raises(RuntimeError, match='Matrix.H'): + pca_comp_op.adjoint(input_data) diff --git a/tests/operators/test_phase_op.py b/tests/operators/test_phase_op.py index aadfbdf24..726569312 100644 --- a/tests/operators/test_phase_op.py +++ b/tests/operators/test_phase_op.py @@ -3,15 +3,23 @@ import torch from mrpro.operators import PhaseOp -from tests import RandomGenerator +from tests import RandomGenerator, autodiff_test def test_phase_operator_forward(): """Test that PhaseOp returns angle of tensors.""" - rng = RandomGenerator(2) - a = rng.complex64_tensor((2, 3)) - b = rng.complex64_tensor((3, 10)) + random_generator = RandomGenerator(seed=2) + a = random_generator.complex64_tensor((2, 3)) + b = random_generator.complex64_tensor((3, 10)) phase_op = PhaseOp() phase_a, phase_b = phase_op(a, b) assert torch.allclose(phase_a, torch.angle(a)) assert torch.allclose(phase_b, torch.angle(b)) + + +def 
test_autodiff_phase_operator(): + """Test autodiff works for phase operator.""" + random_generator = RandomGenerator(seed=2) + a = random_generator.complex64_tensor((5, 9, 8)) + b = random_generator.complex64_tensor((10, 11, 12)) + autodiff_test(PhaseOp(), a, b) diff --git a/tests/operators/test_rearrangeop.py b/tests/operators/test_rearrangeop.py index 3db888897..ecacafb42 100644 --- a/tests/operators/test_rearrangeop.py +++ b/tests/operators/test_rearrangeop.py @@ -3,8 +3,7 @@ import pytest from mrpro.operators.RearrangeOp import RearrangeOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @@ -15,6 +14,7 @@ ((2, 2, 4), '... a b->... (a b)', (2, 8), {'b': 4}), # flatten ((2), '... (a b) -> ... a b', (2, 1), {'b': 1}), # unflatten ], + ids=['swap_axes', 'flatten', 'unflatten'], ) def test_einsum_op(input_shape, rule, output_shape, additional_info, dtype): """Test adjointness and shape.""" diff --git a/tests/operators/test_sensitivity_op.py b/tests/operators/test_sensitivity_op.py index 25efc5bee..5576a9892 100644 --- a/tests/operators/test_sensitivity_op.py +++ b/tests/operators/test_sensitivity_op.py @@ -5,8 +5,7 @@ from mrpro.data import CsmData, QHeader, SpatialDimension from mrpro.operators import SensitivityOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_sensitivity_op_adjointness(): @@ -90,7 +89,7 @@ def test_sensitivity_op_other_dim_compatibility_fail(n_other_csm, n_other_img): # Apply to n_other_img shape u = random_generator.complex64_tensor(size=(n_other_img, 1, *n_zyx)) with pytest.raises(RuntimeError, match='The size of tensor'): - sensitivity_op.forward(u) + sensitivity_op(u) v = random_generator.complex64_tensor(size=(n_other_img, n_coils, *n_zyx)) with 
pytest.raises(RuntimeError, match='The size of tensor'): diff --git a/tests/operators/test_slice_projection_op.py b/tests/operators/test_slice_projection_op.py index 09a296d4c..53209e98d 100644 --- a/tests/operators/test_slice_projection_op.py +++ b/tests/operators/test_slice_projection_op.py @@ -9,8 +9,7 @@ from mrpro.operators import SliceProjectionOp from mrpro.utils.slice_profiles import SliceGaussian, SliceInterpolate, SliceSmoothedRectangular -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_slice_projection_op_cube_basic(): diff --git a/tests/operators/test_wavelet_op.py b/tests/operators/test_wavelet_op.py index 46e90da46..92de01286 100644 --- a/tests/operators/test_wavelet_op.py +++ b/tests/operators/test_wavelet_op.py @@ -8,8 +8,7 @@ from ptwt.conv_transform_2 import wavedec2 from ptwt.conv_transform_3 import wavedec3 -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test, operator_isometry_test, operator_unitary_test +from tests import RandomGenerator, dotproduct_adjointness_test, linear_operator_unitary_test, operator_isometry_test @pytest.mark.parametrize( @@ -168,4 +167,4 @@ def test_wavelet_op_unitary(im_shape, domain_shape, dim, wavelet_name): random_generator = RandomGenerator(seed=0) img = random_generator.complex64_tensor(size=im_shape) wavelet_op = WaveletOp(domain_shape=domain_shape, dim=dim, wavelet_name=wavelet_name) - operator_unitary_test(wavelet_op, img) + linear_operator_unitary_test(wavelet_op, img) diff --git a/tests/operators/test_zero_op.py b/tests/operators/test_zero_op.py index 5778cd69d..1e7f47017 100644 --- a/tests/operators/test_zero_op.py +++ b/tests/operators/test_zero_op.py @@ -1,11 +1,9 @@ -from typing import assert_type - import torch from mrpro.operators import IdentityOp, LinearOperator, MagnitudeOp, Operator, ZeroOp from mrpro.operators.LinearOperator import LinearOperatorSum 
+from typing_extensions import assert_type -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_zero_op_keepshape(): diff --git a/tests/operators/test_zero_pad_op.py b/tests/operators/test_zero_pad_op.py index a56635f5f..5d1f02135 100644 --- a/tests/operators/test_zero_pad_op.py +++ b/tests/operators/test_zero_pad_op.py @@ -4,8 +4,7 @@ import torch from mrpro.operators import ZeroPadOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_zero_pad_op_content(): @@ -20,7 +19,7 @@ def test_zero_pad_op_content(): original_shape=tuple([original_shape[d] for d in padding_dimensions]), padded_shape=tuple([padded_shape[d] for d in padding_dimensions]), ) - (padded_data,) = zero_padding_op.forward(original_data) + (padded_data,) = zero_padding_op(original_data) # Compare overlapping region torch.testing.assert_close(original_data[:, 10:90, :, 50:150, :, :], padded_data[:, :, :, :, 95:145, :]) diff --git a/tests/phantoms/test_ellipse_phantom.py b/tests/phantoms/test_ellipse_phantom.py index 2b34b06b5..c7ab58850 100644 --- a/tests/phantoms/test_ellipse_phantom.py +++ b/tests/phantoms/test_ellipse_phantom.py @@ -5,7 +5,7 @@ from mrpro.data import SpatialDimension from mrpro.operators import FastFourierOp -from tests.helper import relative_image_difference +from tests import relative_image_difference def test_image_space(ellipse_phantom): diff --git a/tests/utils/test_fill_range.py b/tests/utils/test_fill_range.py new file mode 100644 index 000000000..d6dd700c4 --- /dev/null +++ b/tests/utils/test_fill_range.py @@ -0,0 +1,23 @@ +"""Tests for fill_range_""" + +import pytest +import torch +from mrpro.utils import fill_range_ + + +@pytest.mark.parametrize('dtype', [torch.float32, torch.int64], ids=['float32', 'int64']) +def test_fill_range(dtype): + """Test 
functionality of fill_range.""" + tensor = torch.zeros(3, 4, dtype=dtype) + fill_range_(tensor, dim=1) + expected = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]], dtype=tensor.dtype) + torch.testing.assert_close(tensor, expected) + + +def test_fill_range_dim_out_of_range(): + """Test fill_range_ with a dimension out of range.""" + tensor = torch.zeros(3, 4) + with pytest.raises(IndexError, match='Dimension 2 is out of range'): + fill_range_(tensor, dim=2) + with pytest.raises(IndexError, match='Dimension -3 is out of range'): + fill_range_(tensor, dim=-3) diff --git a/tests/utils/test_modify_acq_info.py b/tests/utils/test_modify_acq_info.py deleted file mode 100644 index 451303d02..000000000 --- a/tests/utils/test_modify_acq_info.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for modification of acquisition infos.""" - -from einops import rearrange -from mrpro.utils import modify_acq_info - - -def test_modify_acq_info(random_kheader_shape): - """Test the modification of the acquisition info.""" - # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] - kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape - - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... 
-> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) - - # Verify shape - assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/utils/test_reshape.py b/tests/utils/test_reshape.py new file mode 100644 index 000000000..dd57b8feb --- /dev/null +++ b/tests/utils/test_reshape.py @@ -0,0 +1,53 @@ +"""Tests for reshaping utilities.""" + +import torch +from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right + +from tests import RandomGenerator + + +def test_broadcast_right(): + """Test broadcast_right""" + tensors = (torch.ones(1, 2, 3), torch.ones(1, 2), torch.ones(2)) + broadcasted = broadcast_right(*tensors) + assert broadcasted[0].shape == broadcasted[1].shape == broadcasted[2].shape == (2, 2, 3) + + +def test_unsqueeze_left(): + """Test unsqueeze_left""" + tensor = torch.ones(1, 2, 3) + unsqueezed = unsqueeze_left(tensor, 2) + assert unsqueezed.shape == (1, 1, 1, 2, 3) + assert torch.equal(tensor.ravel(), unsqueezed.ravel()) + + +def test_unsqueeze_right(): + """Test unsqueeze_right""" + tensor = torch.ones(1, 2, 3) + unsqueezed = unsqueeze_right(tensor, 2) + assert unsqueezed.shape == (1, 2, 3, 1, 1) + assert torch.equal(tensor.ravel(), unsqueezed.ravel()) + + +def test_reduce_view(): + """Test reduce_view""" + + tensor = RandomGenerator(0).float32_tensor((1, 2, 3, 1, 1, 1)) + tensor = tensor.expand(1, 2, 3, 4, 1, 1).contiguous() # this cannot be removed + tensor = tensor.expand(7, 2, 3, 4, 5, 6) + + reduced_all = reduce_view(tensor) + assert reduced_all.shape == (1, 2, 3, 4, 1, 1) + assert torch.equal(reduced_all.expand_as(tensor), tensor) + + reduced_two = reduce_view(tensor, (0, -1)) + assert reduced_two.shape == (1, 2, 3, 4, 5, 1) + assert torch.equal(reduced_two.expand_as(tensor), tensor) + + reduced_one_neg = reduce_view(tensor, -1) + assert reduced_one_neg.shape == (7, 2, 3, 4, 5, 1) + assert 
torch.equal(reduced_one_neg.expand_as(tensor), tensor) + + reduced_one_pos = reduce_view(tensor, 0) + assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6) + assert torch.equal(reduced_one_pos.expand_as(tensor), tensor) diff --git a/tests/utils/test_split_idx.py b/tests/utils/test_split_idx.py index 6997fc1bc..30501b9ac 100644 --- a/tests/utils/test_split_idx.py +++ b/tests/utils/test_split_idx.py @@ -5,6 +5,8 @@ from einops import repeat from mrpro.utils import split_idx +from tests import RandomGenerator + @pytest.mark.parametrize( ('ni_per_block', 'ni_overlap', 'cyclic', 'unique_values_in_last_block'), @@ -19,7 +21,7 @@ def test_split_idx(ni_per_block, ni_overlap, cyclic, unique_values_in_last_block # Create a regular sequence of values vals = repeat(torch.tensor([0, 1, 2, 3]), 'd0 -> (d0 repeat)', repeat=5) # Mix up values - vals = vals[torch.randperm(vals.shape[0])] + vals = vals[RandomGenerator(13).randperm(vals.shape[0])] # Split indices of sorted sequence idx_split = split_idx(torch.argsort(vals), ni_per_block, ni_overlap, cyclic) diff --git a/tests/utils/test_unit_conversion.py b/tests/utils/test_unit_conversion.py new file mode 100644 index 000000000..a232a5366 --- /dev/null +++ b/tests/utils/test_unit_conversion.py @@ -0,0 +1,82 @@ +"""Tests of unit conversion.""" + +import numpy as np +import torch +from mrpro.utils.unit_conversion import ( + deg_to_rad, + lamor_frequency_to_magnetic_field, + m_to_mm, + magnetic_field_to_lamor_frequency, + mm_to_m, + ms_to_s, + rad_to_deg, + s_to_ms, +) + +from tests import RandomGenerator + + +def test_mm_to_m(): + """Verify mm to m conversion.""" + generator = RandomGenerator(seed=0) + mm_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(mm_to_m(mm_input), mm_input / 1000.0) + + +def test_m_to_mm(): + """Verify m to mm conversion.""" + generator = RandomGenerator(seed=0) + m_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(m_to_mm(m_input), m_input * 1000.0) + + +def 
test_ms_to_s(): + """Verify ms to s conversion.""" + generator = RandomGenerator(seed=0) + ms_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(ms_to_s(ms_input), ms_input / 1000.0) + + +def test_s_to_ms(): + """Verify s to ms conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(s_to_ms(s_input), s_input * 1000.0) + + +def test_rad_to_deg_tensor(): + """Verify radians to degree conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(rad_to_deg(s_input), torch.rad2deg(s_input)) + + +def test_deg_to_rad_tensor(): + """Verify degree to radians conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(deg_to_rad(s_input), torch.deg2rad(s_input)) + + +def test_rad_to_deg_float(): + """Verify radians to degree conversion.""" + assert rad_to_deg(np.pi / 2) == 90.0 + + +def test_deg_to_rad_float(): + """Verify degree to radians conversion.""" + assert deg_to_rad(180.0) == np.pi + + +def test_lamor_frequency_to_magnetic_field(): + """Verify conversion of lamor frequency to magnetic field.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + proton_lamor_frequency_at_3tesla = 127.74 * 1e6 + assert lamor_frequency_to_magnetic_field(proton_lamor_frequency_at_3tesla, proton_gyromagnetic_ratio) == 3.0 + + +def test_magnetic_field_to_lamor_frequency(): + """Verify conversion of magnetic field to lamor frequency.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + magnetic_field_strength = 3.0 + assert magnetic_field_to_lamor_frequency(magnetic_field_strength, proton_gyromagnetic_ratio) == 127.74 * 1e6