diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 27bef5c9..0cd91e4d 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -10,6 +10,8 @@ on: env: PYTHON_VERSION: "3.10" + BART_VERSION: "0.8.00" + ref_backend: "finufft" jobs: linter-check: @@ -21,12 +23,18 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION }} - - name: Black setup + cache: pip + + - name: Install Python deps shell: bash - run: pip install black ruff + run: | + python -m pip install --upgrade pip + python -m pip install -e .[test,dev] + - name: Black Check shell: bash run: black . --diff --color --check + - name: ruff Check shell: bash run: ruff src @@ -34,6 +42,12 @@ jobs: test-cpu: runs-on: ubuntu-latest needs: linter-check + strategy: + matrix: + backend: [finufft, pynfft, pynufft, bart, sigpy] + exclude: + - backend: bart + - backend: pynfft steps: - uses: actions/checkout@v3 @@ -41,69 +55,86 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION }} + cache: pip + - name: Install Dependencies shell: bash run: | - sudo apt install -y libnfft3-dev python --version + + - name: Install Python Deps + shell: bash + run: | python -m pip install --upgrade pip python -m pip install -e .[test] - - name: Install CPU Backends + - name: Install pynfft + if: ${{ matrix.backend == 'pynfft' || env.ref_backend == 'pynfft' }} shell: bash run: | - python -m pip install finufft pynfft2 "cython<3.0.0" + sudo apt install -y libnfft3-dev + python -m pip install pynfft2 "cython<3.0.0" - - name: Run Tests + - name: Install pynufft + if: ${{ matrix.backend == 'pynufft' || env.ref_backend == 'pynufft' }} + run: python -m pip install pynufft + + - name: Install finufft + if: ${{ matrix.backend == 'finufft' || env.ref_backend == 'finufft'}} shell: bash - run: | - coverage run -m pytest -n auto -v - coverage report + run: python -m pip install finufft + + - name: Install Sigpy + if: ${{ matrix.backend == 'sigpy' || env.ref_backend == 'sigpy'}} + shell: bash + run: python -m pip install sigpy + + + - name: Install BART + if: ${{ matrix.backend == 'bart' || env.ref_backend == 'bart'}} + shell: bash + run: | + cd $RUNNER_WORKSPACE + sudo apt-get install make gcc libfftw3-dev liblapacke-dev libpng-dev libopenblas-dev + wget https://github.com/mrirecon/bart/archive/v${{ env.BART_VERSION }}.tar.gz + tar xzvf v${{ env.BART_VERSION }}.tar.gz + cd bart-${{ env.BART_VERSION }} + make + echo $PWD >> $GITHUB_PATH + - name: Run Tests + shell: bash + run: | + export COVERAGE_FILE=coverage_${{ matrix.backend }} + pytest --backend ${{ matrix.backend }} --ref ${{ env.ref_backend }} --cov --disable-pytest-warnings --cov-branch --cov-report=term - name: Upload coverage uses: actions/upload-artifact@v3 with: - name: coverage-cpu-${{ github.sha }} - path: .coverage - - #https://stackoverflow.com/a/74411469/16019838 - # - - name: Abort other jobs - if: failure() - uses: actions/github-script@v6 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - GITHUB_RUN_ID: ${{ github.run_id }} - with: - script: | - const RUN_ID = process.env.GITHUB_RUN_ID - const [OWNER, REPO] = process.env.GITHUB_REPOSITORY.split("/"); - const resp = await github.request('POST /repos/{owner}/{repo}/actions/runs/{run_id}/cancel', { - owner: OWNER, - repo: REPO, - run_id: RUN_ID, - headers: {'X-GitHub-Api-Version': '2022-11-28'}}) - + name: coverage_data + path: coverage_${{ matrix.backend}} test-gpu: runs-on: GPU needs: linter-check strategy: - fail-fast: true + matrix: + backend: 
[gpunufft, cufinufft] steps: - uses: actions/checkout@v3 - - name: Install mri-nufft + - name: Install mri-nufft and finufft shell: bash run: | cd $RUNNER_WORKSPACE python --version python -m venv venv source $RUNNER_WORKSPACE/venv/bin/activate - pip install --upgrade pip wheel numpy + pip install --upgrade pip wheel pip install -e mri-nufft[test] + pip install finufft - name: Install Cufinufft + if: ${{ matrix.backend == 'cufinufft' }} shell: bash run: | cd $RUNNER_WORKSPACE @@ -127,6 +158,7 @@ jobs: cd $RUNNER_WORKSPACE - name: Install gpuNUFFT + if: ${{ matrix.backend == 'gpunufft' }} shell: bash run: | cd $RUNNER_WORKSPACE @@ -147,34 +179,15 @@ jobs: run: | cd $RUNNER_WORKSPACE/mri-nufft source $RUNNER_WORKSPACE/venv/bin/activate - python -m coverage run -m pytest -n auto -v - coverage report - + export COVERAGE_FILE=coverage_${{ matrix.backend }} + python -m pytest --ref ${{ env.ref_backend }} --backend ${{ matrix.backend }} --disable-pytest-warnings --cov --cov-branch --cov-report=term - name: Upload coverage if: success() uses: actions/upload-artifact@v3 with: - name: coverage-gpu-${{ github.sha }} - path: .coverage - - #https://stackoverflow.com/a/74411469/16019838 - # - - name: Abort other jobs - if: failure() - uses: actions/github-script@v6 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - GITHUB_RUN_ID: ${{ github.run_id }} - with: - script: | - const RUN_ID = process.env.GITHUB_RUN_ID - const [OWNER, REPO] = process.env.GITHUB_REPOSITORY.split("/"); - const resp = await github.request('POST /repos/{owner}/{repo}/actions/runs/{run_id}/cancel', { - owner: OWNER, - repo: REPO, - run_id: RUN_ID, - headers: {'X-GitHub-Api-Version': '2022-11-28'}}) + name: coverage_data + path: coverage_${{ matrix.backend }} - name: Cleanup if: always() @@ -186,44 +199,66 @@ jobs: rm -rf gpuNUFFT rm -rf venv + test-examples: + runs-on: ubuntu-latest + needs: linter-check + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + + - name: Install Python deps + shell: bash + run: | + python -m pip install --upgrade pip + python -m pip install -e .[test,dev] + python -m pip install finufft pooch brainweb-dl + + - name: Run examples + shell: bash + run: | + export COVERAGE_FILE=coverage_plots + pytest examples --cov --cov-branch --cov-report=term + + - name: Upload coverage + if: success() + uses: actions/upload-artifact@v3 + with: + name: coverage_data + path: coverage_plots + coverage: runs-on: ubuntu-latest - needs: [test-cpu, test-gpu] + needs: [test-cpu, test-gpu, test-examples] + if: ${{ always() }} steps: - name: Checkout uses: actions/checkout@v3 - - name: Download coverage CPU - uses: actions/download-artifact@v3 - with: - name: coverage-cpu-${{ github.sha }} - path: cov-cpu - - name: Download coverage GPU - uses: actions/download-artifact@v3 - with: - name: coverage-gpu-${{ github.sha }} - path: cov-gpu + + - name: Collect Coverages + uses: actions/download-artifact@v2 - name: Set up Python 3.10 uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION }} + cache: pip + - name: add the coverage tool shell: bash run: | python -m pip install --upgrade pip python -m pip install coverage[toml] python -m pip install -e . 
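+      # Every test job uploads its coverage file to the same "coverage_data" artifact,
+      # so the bare download-artifact step above should restore them all under ./coverage_data/.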
+ - name: Combine coverage - shell: bash + run: coverage combine -a coverage_data/* + + - name: Reports run: | - coverage combine cov-cpu/.coverage cov-gpu/.coverage - coverage xml -i + coverage xml coverage report - ls -al - - - name: Upload Join - if: success() - uses: actions/upload-artifact@v3 - with: - name: coverage-full-${{ github.sha}} - path: .coverage diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/conftest.py b/examples/conftest.py new file mode 100644 index 00000000..4f9ad711 --- /dev/null +++ b/examples/conftest.py @@ -0,0 +1,52 @@ +"""TEST CONFIGURATION. + +This module contains methods for configuring the testing of the example +scripts. + +:Author: Pierre-Antoine Comby + +Notes +----- +Based on: +https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test + +""" +from pathlib import Path +import runpy +import pytest +import matplotlib as mpl + + +mpl.use("agg") + + +def pytest_collect_file(path, parent): + """Pytest hook. + + Create a collector for the given path, or None if not relevant. + The new node needs to have the specified parent as parent. + """ + p = Path(path) + if p.suffix == ".py" and "example" in p.name: + return Script.from_parent(parent, path=p, name=p.name) + + +class Script(pytest.File): + """Script files collected by pytest.""" + + def collect(self): + """Collect the script as its own item.""" + yield ScriptItem.from_parent(self, name=self.name) + + +class ScriptItem(pytest.Item): + """Item script collected by pytest.""" + + def runtest(self): + """Run the script as a test.""" + runpy.run_path(str(self.path)) + + def repr_failure(self, excinfo): + """Return only the error traceback of the script.""" + excinfo.traceback = excinfo.traceback.cut(path=self.path) + return super().repr_failure(excinfo) diff --git a/pyproject.toml b/pyproject.toml index 7b463547..77c60d4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dynamic = ["version"] [project.optional-dependencies] gpu = ["cupy-wheel"] -test = ["pytest", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases", "pynfft2"] +test = ["pytest", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"] dev = ["black", "isort", "ruff"] doc = ["sphinx-book-theme","sphinx-copybutton", "sphinx-gallery", "matplotlib", "pooch", "brainweb-dl"] @@ -57,6 +57,7 @@ profile="black" [tool.pytest.ini_options] minversion = "6.0" norecursedirs = ["tests/helpers"] +testpaths=["tests"] [tool.pylsp-mypy] enabled = false diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 5860c295..272c1fc9 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -6,7 +6,7 @@ :author: Pierre-Antoine Comby """ from abc import ABC, abstractmethod - +from functools import partial import warnings import numpy as np @@ -98,14 +98,21 @@ class or instance of class if args or kwargs are given. ------ ValueError if the backend is not available. """ + available = True try: available, operator = FourierOperatorBase.interfaces[backend_name] except KeyError as exc: - raise ValueError(f"backend {backend_name} is not available") from exc + if not backend_name.startswith("stacked-"): + raise ValueError(f"backend {backend_name} does not exist") from exc + # try to get the backend with stacked + # Dedicated registered stacked backend (like stacked-cufinufft) + # have be found earlier. 
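+            # e.g. "stacked-finufft" resolves to the generic "stacked" operator,
+            # with backend="finufft" bound via functools.partial.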
+ backend = backend_name.split("-")[1] + operator = get_operator("stacked") + operator = partial(operator, backend=backend) + if not available: - raise ValueError( - f"backend {backend_name} is registered, but dependencies are not met." - ) + raise ValueError(f"backend {backend_name} found, but dependencies are not met.") if args or kwargs: operator = operator(*args, **kwargs) @@ -170,6 +177,14 @@ def adj_op(self, coeffs): """ pass + def data_consistency(self, image, obs_data): + """Compute the gradient data consistency. + + This is the naive implementation using adj_op(op(x)-y). + Specific backend can (and should!) implement a more efficient version. + """ + return self.adj_op(self.op(image) - obs_data) + def with_off_resonnance_correction(self, B, C, indices): """Return a new operator with Off Resonnance Correction.""" from ..off_resonnance import MRIFourierCorrected @@ -290,6 +305,8 @@ def __repr__(self): class FourierOperatorCPU(FourierOperatorBase): """Base class for CPU-based NUFFT operator. + The NUFFT operation will be done sequentially and looped over coils and batches. + Parameters ---------- samples: np.ndarray @@ -315,20 +332,25 @@ def __init__( shape, density=False, n_coils=1, + n_batchs=1, + n_trans=1, smaps=None, raw_op=None, + squeeze_dims=True, ): super().__init__() + self.shape = shape - self.samples = proper_trajectory(samples, normalize="unit") - self._dtype = self.samples.dtype - self._uses_sense = False + + # we will access the samples by their coordinate first. + self.samples = samples.reshape(-1, len(shape)) + self.dtype = self.samples.dtype # Density Compensation Setup if density is True: - self.density = self.estimate_density(samples, shape) + self.density = self.estimate_density(self.samples, shape) elif isinstance(density, np.ndarray): - if len(density) != len(samples): + if len(density) != len(self.samples): raise ValueError( "Density array and samples array should have the same length." ) @@ -339,15 +361,10 @@ def __init__( if n_coils < 1: raise ValueError("n_coils should be ≥ 1") self.n_coils = n_coils - if smaps is not None: - self._uses_sense = True - if isinstance(smaps, np.ndarray): - raise ValueError("Smaps should be either a C-ordered ndarray") - self._smaps = smaps - else: - self._uses_sense = False - - # Raw_op should be instantiated by subclasses. + self.smaps = smaps + self.n_batchs = n_batchs + self.n_trans = n_trans + self.squeeze_dims = squeeze_dims self.raw_op = raw_op @@ -379,29 +396,41 @@ def op(self, data, ksp=None): ret = self._op_sense(data, ksp) # calibrationless or monocoil. 
else: - if data.ndim == self.ndim: - data = np.expand_dims(data, axis=0) # add coil dimension ret = self._op_calibless(data, ksp) - return ret + ret /= self.norm_factor + return self._safe_squeeze(ret) - def _op_sense(self, data, ksp_d=None): - coil_img = np.empty((self.n_coils, *self.shape), dtype=data.dtype) - ksp = np.zeros((self.n_coils, self.n_samples), dtype=data.dtype) - coil_img = data * self._smaps - self._op(coil_img) + def _op_sense(self, data, ksp=None): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + dataf = data.reshape((B, *XYZ)) + if ksp is None: + ksp = np.empty((B * C, K), dtype=self.cpx_dtype) + for i in range(B * C // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ)) + coil_img *= dataf[idx_batch] + self._op(coil_img, ksp[i * T : (i + 1) * T]) + ksp = ksp.reshape((B, C, K)) return ksp def _op_calibless(self, data, ksp=None): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape if ksp is None: - ksp = np.empty((self.n_coils, self.n_samples), dtype=data.dtype) - for i in range(self.n_coils): - self._op(data[i], ksp[i]) + ksp = np.empty((B * C, K), dtype=self.cpx_dtype) + dataf = np.reshape(data, (B * C, *XYZ)) + for i in range((B * C) // T): + self._op( + dataf[i * T : (i + 1) * T], + ksp[i * T : (i + 1) * T], + ) + ksp = ksp.reshape((B, C, K)) return ksp def _op(self, image, coeffs): self.raw_op.op(coeffs, image) - coeffs /= self.norm_factor - return coeffs def adj_op(self, coeffs, img=None): """Non Cartesian MRI adjoint operator. @@ -424,34 +453,49 @@ def adj_op(self, coeffs, img=None): # calibrationless or monocoil. else: ret = self._adj_op_calibless(coeffs, img) - return ret + ret /= self.norm_factor + return self._safe_squeeze(ret) def _adj_op_sense(self, coeffs, img=None): - coil_img = np.empty(self.shape, dtype=coeffs.dtype) + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape if img is None: - img = np.zeros(self.shape, dtype=coeffs.dtype) - self._adj_op(coeffs, coil_img) - img = np.sum(coil_img * self._smaps.conjugate(), axis=0) + img = np.zeros((B, *XYZ), dtype=self.cpx_dtype) + coeffs_flat = coeffs.reshape((B * C, K)) + img_batched = np.zeros((T, *XYZ), dtype=self.cpx_dtype) + for i in range(B * C // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + self._adj_op(coeffs_flat[i * T : (i + 1) * T], img_batched) + img_batched *= self.smaps[idx_coils].conj() + for t, b in enumerate(idx_batch): + img[b] += img_batched[t] + img = img.reshape((B, 1, *XYZ)) return img def _adj_op_calibless(self, coeffs, img=None): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape if img is None: - img = np.zeros((self.n_coils, *self.shape), dtype=coeffs.dtype) - self._adj_op(coeffs, img) - return img + img = np.empty((B * C, *XYZ), dtype=self.cpx_dtype) + coeffs_f = np.reshape(coeffs, (B * C, K)) + for i in range((B * C) // T): + self._adj_op(coeffs_f[i * T : (i + 1) * T], img[i * T : (i + 1) * T]) - def _apply_dc(self, coeffs): - if self.density is not None: - return coeffs * self.density - return coeffs + img = img.reshape((B, C, *XYZ)) + return img def _adj_op(self, coeffs, image): - self.raw_op.adj_op(self._apply_dc(coeffs), image) - image /= self.norm_factor - return image + if self.density is not None: + coeffs2 = coeffs.copy() + for i in 
range(self.n_trans): + coeffs2[i * self.n_samples : (i + 1) * self.n_samples] *= self.density + else: + coeffs2 = coeffs + self.raw_op.adj_op(coeffs2, image) def data_consistency(self, image_data, obs_data): - """Compute the gradient estimation directly on gpu. + """Compute the gradient data consistency. This mixes the op and adj_op method to perform F_adj(F(x-y)) on a per coil basis. By doing the computation coil wise, @@ -466,31 +510,61 @@ def data_consistency(self, image_data, obs_data): Observed data. """ if self.uses_sense: - return self._data_consistency_sense(image_data, obs_data) - return self._data_consistency_calibless(image_data, obs_data) - - def _data_consistency_sense(self, image_data, obs_data): - img = np.empty_like(image_data) - coil_img = np.empty(self.shape, dtype=image_data.dtype) - coil_ksp = np.empty(self.n_samples, dtype=obs_data.dtype) - for i in range(self.n_coils): - np.copyto(coil_img, img) - coil_img *= self._smap + return self._safe_squeeze(self._grad_sense(image_data, obs_data)) + return self._safe_squeeze(self._grad_calibless(image_data, obs_data)) + + def _grad_sense(self, image_data, obs_data): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + + dataf = image_data.reshape((B, *XYZ)) + obs_dataf = obs_data.reshape((B * C, K)) + grad = np.empty_like(dataf) + + coil_img = np.empty((T, *XYZ), dtype=self.cpx_dtype) + coil_ksp = np.empty((T, K), dtype=self.cpx_dtype) + for i in range(B * C // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ)) + coil_img *= dataf[idx_batch] self._op(coil_img, coil_ksp) - coil_ksp -= obs_data[i] - if self.uses_density: - coil_ksp *= self.density_d + coil_ksp /= self.norm_factor + coil_ksp -= obs_dataf[i * T : (i + 1) * T] self._adj_op(coil_ksp, coil_img) - img += coil_img * self._smaps[i].conjugate() - return img - - def _data_consistency_calibless(self, image_data, obs_data): - img = np.empty((self.n_coils, *self.shape), dtype=image_data.dtype) - ksp = np.empty(self.n_samples, dtype=obs_data.dtype) - for i in range(self.n_coils): - self._op(image_data[i], ksp) - ksp -= obs_data[i] + coil_img *= self.smaps[idx_coils].conj() + for t, b in enumerate(idx_batch): + grad[b] += coil_img[t] + grad /= self.norm_factor + return grad + + def _grad_calibless(self, image_data, obs_data): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + + dataf = image_data.reshape((B * C, *XYZ)) + obs_dataf = obs_data.reshape((B * C, K)) + grad = np.empty_like(dataf) + ksp = np.empty((T, K), dtype=self.cpx_dtype) + for i in range(B * C // T): + self._op(dataf[i * T : (i + 1) * T], ksp) + ksp /= self.norm_factor + ksp -= obs_dataf[i * T : (i + 1) * T] if self.uses_density: - ksp *= self.density_d - self._adj_op(ksp, img[i]) - return img + ksp *= self.density + self._adj_op(ksp, grad[i * T : (i + 1) * T]) + grad /= self.norm_factor + return grad.reshape(B, C, *XYZ) + + def _safe_squeeze(self, arr): + """Squeeze the first two dimensions of shape of the operator.""" + if self.squeeze_dims: + try: + arr = arr.squeeze(axis=1) + except ValueError: + pass + try: + arr = arr.squeeze(axis=0) + except ValueError: + pass + return arr diff --git a/src/mrinufft/operators/interfaces/bart.py b/src/mrinufft/operators/interfaces/bart.py new file mode 100644 index 00000000..79989d12 --- /dev/null +++ b/src/mrinufft/operators/interfaces/bart.py @@ -0,0 +1,285 @@ +"""Interface for the 
BART NUFFT. + +BART uses a command line interfaces, and read/writes data to files. + +The file format is described here: https://bart-doc.readthedocs.io/en/latest/data.html#non-cartesian-datasets + +""" + +import warnings +import os +import numpy as np +import mmap +import subprocess as subp +import tempfile +from pathlib import Path + +from ..base import FourierOperatorCPU, proper_trajectory + +# available if return code is 0 +BART_AVAILABLE = not subp.call( + ["which", "bart"], stdout=subp.DEVNULL, stderr=subp.DEVNULL +) + + +class RawBartNUFFT: + """Wrapper around BART NUFFT CLI.""" + + def __init__(self, samples, shape, extra_op_args=None, extra_adj_op_args=None): + self.samples = samples # To normalize and send to file + self.shape = shape + self.shape_str = ":".join([str(s) for s in shape]) + self.shape_str += ":1" if len(shape) == 2 else "" + self._op_args = extra_op_args or [] + self._adj_op_args = extra_adj_op_args or [] + self._temp_dir = tempfile.TemporaryDirectory() + + # Write trajectory to temp file + tmp_path = Path(self._temp_dir.name) + self._traj_file = tmp_path / "traj" + self._ksp_file = tmp_path / "ksp" + self._grid_file = tmp_path / "grid" + + traj2cfl(self.samples, self.shape, self._traj_file) + + def _tmp_file(self): + """Return a temporary file name.""" + return os.path.join(self._temp_dir.name, next(tempfile._get_candidate_names())) + + def __del__(self): + """Delete also the temporary files.""" + self._temp_dir.cleanup() + + def op(self, coeffs_data, grid_data): + """Forward Operator.""" + grid_data_ = grid_data.reshape(self.shape) + _writecfl(grid_data_, self._grid_file) + cmd = [ + "bart", + "nufft", + "-d", + self.shape_str, + *self._op_args, + str(self._traj_file), + str(self._grid_file), + str(self._ksp_file), + ] + try: + subp.run(cmd, check=True, capture_output=True) + except subp.CalledProcessError as exc: + msg = "Failed to run BART NUFFT\n" + msg += f"error code: {exc.returncode}\n" + msg += "cmd: " + " ".join(cmd) + "\n" + msg += f"stdout: {exc.output}\n" + msg += f"stderr: {exc.stderr}" + raise RuntimeError(msg) from exc + + ksp_raw = _readcfl(self._ksp_file) + np.copyto(coeffs_data, ksp_raw) + return coeffs_data + + def adj_op(self, coeffs_data, grid_data): + """Adjoint Operator.""" + # Format grid data to cfl format, and write to file + # Run bart nufft with argument in subprocess + + coeffs_ = coeffs_data.reshape(len(self.samples)) + _writecfl(coeffs_[None, ..., None, None, None], self._ksp_file) + + cmd = [ + "bart", + "nufft", + "-d", + self.shape_str, + "-a" if "-i" not in self._adj_op_args else "", + *self._adj_op_args, + str(self._traj_file), + str(self._ksp_file), + str(self._grid_file), + ] + try: + subp.run(cmd, check=True, capture_output=True) + except subp.CalledProcessError as exc: + msg = "Failed to run BART NUFFT\n" + msg += f"error code: {exc.returncode}\n" + msg += "cmd: " + " ".join(cmd) + "\n" + msg += f"stdout: {exc.output}\n" + msg += f"stderr: {exc.stderr}" + raise RuntimeError(msg) from exc + + grid_raw = _readcfl(self._grid_file) + np.copyto(grid_data, grid_raw) + return grid_data + + +class MRIBartNUFFT(FourierOperatorCPU): + """BART implementation of MRI NUFFT transform.""" + + # TODO override Data consistency function: use toepliz + + backend = "bart" + available = BART_AVAILABLE + + def __init__( + self, + samples, + shape, + density=False, + n_coils=1, + n_batchs=1, + smaps=None, + squeeze_dims=True, + **kwargs, + ): + samples_ = proper_trajectory(samples, normalize="unit") + if density is True: + density = False + if 
getattr(kwargs, "extra_adj_op_args", None): + kwargs["extra_adj_op_args"] += ["-i"] + else: + kwargs["extra_adj_op_args"] = ["-i"] + + super().__init__( + samples_, + shape, + density, + n_coils=n_coils, + n_batchs=n_batchs, + n_trans=1, + smaps=smaps, + squeeze_dims=squeeze_dims, + ) + + self.raw_op = RawBartNUFFT(samples_, shape, **kwargs) + + @property + def norm_factor(self): + """Normalization factor of the operator.""" + # return 1.0 + return np.sqrt(2 ** len(self.shape)) + + +def _readcfl(cfl_file, hdr_file=None): + """Read a pair of .cfl/.hdr file to get a complex numpy array. + + Adapted from the BART python cfl library. + + Parameters + ---------- + name : str + Name of the file to read. + + Returns + ------- + array : array_like + Complex array read from the file. + """ + basename = Path(cfl_file).with_suffix("") + if hdr_file is None: + hdr_file = basename.with_suffix(".hdr") + cfl_file = basename.with_suffix(".cfl") + + # get dims from .hdr + with open(hdr_file) as h: + h.readline() # skip + line = h.readline() + dims = [int(i) for i in line.split()] + + # remove singleton dimensions from the end + n = np.prod(dims) + dims_prod = np.cumprod(dims) + dims = dims[: np.searchsorted(dims_prod, n) + 1] + + # load data and reshape into dims + with open(cfl_file, "rb") as d: + a = np.fromfile(d, dtype=np.complex64, count=n) + return a.reshape(dims, order="F") + + +def _writecfl(array, cfl_file, hdr_file=None): + """Write a pair of .cfl/.hdr file representing a complex array. + + Adapted from the BART python cfl library. + + Parameters + ---------- + name : str + Name of the file to write. + array : array_like + Array to write to file. + + """ + basename = Path(cfl_file).with_suffix("") + if hdr_file is None: + hdr_file = basename.with_suffix(".hdr") + cfl_file = basename.with_suffix(".cfl") + + with open(hdr_file, "w") as h: + h.write("# Dimensions\n") + for i in array.shape: + h.write("%d " % i) + h.write("\n") + + size = np.prod(array.shape) * np.dtype(np.complex64).itemsize + + with open(cfl_file, "w+b") as d: + os.ftruncate(d.fileno(), size) + mm = mmap.mmap(d.fileno(), size, flags=mmap.MAP_SHARED, prot=mmap.PROT_WRITE) + if array.dtype != np.complex64: + array = array.astype(np.complex64) + mm.write(np.ascontiguousarray(array.T)) + mm.close() + + +def traj2cfl(traj, shape, basename): + """ + Export a trajectory defined in MRI-nufft to a BART compatible format. + + Parameters + ---------- + traj: array_like + trajectory array, shape (N_shot, N_points, 2 or 3) + shape: tuple + volume shape (FOV) + + The trajectory will be normalized to -(FOV-1)/2 +(FOV-1)/2, + and reshape to BART format. + """ + traj_ = traj * (np.array(shape) - 1) + if traj.shape[-1] == 2: + traj_3d = np.zeros(traj_.shape[:-1] + (3,), dtype=traj_.dtype) + traj_3d[..., :2] = traj_ + traj_ = traj_3d + else: + traj_ = traj_.astype(np.complex64) + traj_ = traj_[None, None, ...] + + _writecfl(traj_.T, basename) + + +def cfl2traj(basename, shape=None): + """Convert a trajectory BART file to a numpy array compatible with MRI-nufft. + + Parameters + ---------- + filename: str + Base filename for the trajectory + shape: optional + Shape of the Image domain. 
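+
+    Returns
+    -------
+    np.ndarray
+        Trajectory in MRI-nufft convention, with samples roughly in the
+        [-0.5, 0.5] range and singleton dimensions squeezed.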
+ """ + traj_raw = _readcfl(basename) + # Convert to float array and take only the real part + traj = np.ascontiguousarray(traj_raw.T.view("(2,)float32")[..., 0]) + if np.all(traj[..., -1] == 0): + warnings.warn("2D Trajectory Detected") + traj = traj[..., :-1] + if shape is None: + maxs = [np.max(traj[..., i]) for i in range(traj.shape[-1])] + mins = [np.min(traj[..., i]) for i in range(traj.shape[-1])] + shape = np.array(maxs) - np.array(mins) + warnings.warn(f"Estimated shape {shape}") + else: + shape = np.asarray(shape) - 1 + + traj /= np.asarray(shape) + return np.squeeze(traj) diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index ea3af3e8..8d804b70 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -13,7 +13,7 @@ pin_memory, sizeof_fmt, ) -from ._cupy_kernels import sense_adj_mono, update_density +from ._cupy_kernels import update_density CUFINUFFT_AVAILABLE = CUPY_AVAILABLE try: @@ -237,7 +237,6 @@ def __init__( smaps=None, smaps_cached=False, verbose=False, - persist_plan=True, squeeze_dims=False, n_trans=1, **kwargs, @@ -293,9 +292,6 @@ def __init__( self.smaps = pin_memory(smaps.astype(self.cpx_dtype)) self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype) - # Initialise NUFFT plans - self.persist_plan = persist_plan - self.raw_op = RawCufinufftPlan( self.samples, tuple(shape), @@ -328,9 +324,6 @@ def op(self, data, ksp_d=None): else: check_size(data, (self.n_batchs, self.n_coils, *self.shape)) data = data.astype(self.cpx_dtype) - if not self.persist_plan or self.raw_op.plans[2] is None: - self.raw_op._make_plan(2) - self.raw_op._set_pts(2) # Dispatch to special case. if self.uses_sense and is_cuda_array(data): @@ -342,25 +335,29 @@ def op(self, data, ksp_d=None): else: op_func = self._op_calibless_host ret = op_func(data, ksp_d) - if not self.persist_plan: - self.raw_op._destroy_plan(2) ret /= self.norm_factor return self._safe_squeeze(ret) def _op_sense_device(self, data, ksp_d=None): - # FIXME: add batch support. - ksp_d = ksp_d or cp.empty((self.n_coils, self.n_samples), dtype=self.cpx_dtype) - img_d = cp.asarray(data, dtype=self.cpx_dtype) - coil_img_d = cp.empty(self.shape, dtype=self.cpx_dtype) - for i in range(self.n_coils): - if self.smaps_cached: - coil_img_d = img_d * self.smaps[i] # sense forward + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + + image_dataf = cp.reshape(data, (B, *XYZ)) + ksp_d = ksp_d or cp.empty((B * C, K), dtype=self.cpx_dtype) + smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + for i in range(B * C // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + data_batched = image_dataf[idx_batch].reshape((T, *XYZ)) + if not self.smaps_cached: + smaps_batched.set(self.smaps[idx_coils].reshape((T, *XYZ))) else: - self._smap_d.set(self.smaps[i]) - coil_img_d = img_d * self._smap_d[i] # sense forward - self.__op(get_ptr(coil_img_d), get_ptr(ksp_d) + i * self.ksp_size) - return ksp_d + smaps_batched = self.smaps[idx_coils].reshape((T, *XYZ)) + data_batched *= smaps_batched + self.__op(get_ptr(data_batched), get_ptr(ksp_d[i * T : (i + 1) * T])) + + return ksp_d.reshape((B, C, K)) def _op_sense_host(self, data, ksp=None): T, B, C = self.n_trans, self.n_batchs, self.n_coils @@ -443,10 +440,6 @@ def adj_op(self, coeffs, img_d=None): Array in the same memory space of coeffs. (ie on cpu or gpu Memory). 
""" check_size(coeffs, (self.n_batchs, self.n_coils, self.n_samples)) - if not self.persist_plan or self.raw_op.plans[1] is None: - self.raw_op._make_plan(1) - self.raw_op._set_pts(1) - # Dispatch to special case. if self.uses_sense and is_cuda_array(coeffs): adj_op_func = self._adj_op_sense_device @@ -459,8 +452,6 @@ def adj_op(self, coeffs, img_d=None): ret = adj_op_func(coeffs, img_d) ret /= self.norm_factor - if not self.persist_plan: - self.raw_op._destroy_plan(1) return self._safe_squeeze(ret) def _adj_op_sense_device(self, coeffs, img_d=None): @@ -599,98 +590,158 @@ def data_consistency(self, image_data, obs_data): obs_data: array Observed data. """ - check_size(obs_data, (self.n_batchs, self.n_coils, self.n_samples)) - if self.uses_sense: - check_size(image_data, (self.n_batchs, *self.shape)) - else: - check_size(image_data, (self.n_batchs, self.n_coils, *self.shape)) - - if not self.persist_plan or self.raw_op.plans[1] is None: - self.raw_op._make_plan(1) - self.raw_op._set_pts(1) + B, C = self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + check_size(obs_data, (B, C, K)) if self.uses_sense: - dc_func = self._data_consistency_sense + check_size(image_data, (B, *XYZ)) else: - dc_func = self._data_consistency_calibless - ret = dc_func(image_data, obs_data) - if not self.persist_plan: - self.raw_op._destroy_plan(1) + check_size(image_data, (B, C, *XYZ)) + + if self.uses_sense and is_host_array(image_data): + grad_func = self._grad_sense_host + elif self.uses_sense and is_cuda_array(image_data): + grad_func = self._grad_sense_device + elif not self.uses_sense and is_host_array(image_data): + grad_func = self._grad_calibless_host + elif not self.uses_sense and is_cuda_array(image_data): + grad_func = self._grad_calibless_device + else: + raise ValueError("No suitable gradient function found.") + ret = grad_func(image_data, obs_data) return self._safe_squeeze(ret) - def _data_consistency_sense(self, image_data, obs_data): - img_d = cp.array(image_data, copy=True) - coil_img_d = cp.empty(self.shape, dtype=self.cpx_dtype) - coil_ksp_d = cp.empty(self.n_samples, dtype=self.cpx_dtype) - if is_host_array(obs_data): - coil_obs_data = cp.empty(self.n_samples, dtype=self.cpx_dtype) - obs_data_pinned = pin_memory(obs_data) - for i in range(self.n_coils): - cp.copyto(coil_img_d, img_d) - if self.smaps_cached: - coil_img_d *= self.smaps[i] - else: - self._smap_d.set(self._smaps[i]) - coil_img_d *= self._smap_d - self.__op(get_ptr(coil_img_d), get_ptr(coil_ksp_d)) - coil_obs_data = cp.asarray(obs_data_pinned[i]) - coil_ksp_d -= coil_obs_data - if self.uses_density: - coil_ksp_d *= self.density_d - self.__adj_op(get_ptr(coil_ksp_d), get_ptr(coil_img_d)) - if self.smaps_cached: - sense_adj_mono(img_d, coil_img_d, self.smaps[i]) - else: - sense_adj_mono(img_d, coil_img_d, self._smap_d) - del obs_data_pinned - return img_d.get() - - for i in range(self.n_coils): - cp.copyto(coil_img_d, img_d) - if self.smaps_cached: - coil_img_d *= self.smaps[i] + def _grad_sense_host(self, image_data, obs_data): + """Gradient computation when all data is on host.""" + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + + image_dataf = np.reshape(image_data, (B, *XYZ)) + obs_dataf = np.reshape(obs_data, (B * C, K)) + + data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + + ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) + obs_batched = cp.empty((T, K), dtype=self.cpx_dtype) + + grad_d = 
cp.zeros((B, *XYZ), dtype=self.cpx_dtype) + grad = np.empty((B, *XYZ), dtype=self.cpx_dtype) + for i in range(B * C // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + data_batched.set(image_dataf[idx_batch].reshape((T, *XYZ))) + obs_batched.set(obs_dataf[i * T : (i + 1) * T]) + + if not self.smaps_cached: + smaps_batched.set(self.smaps[idx_coils].reshape((T, *XYZ))) else: - self._smap_d.set(self._smaps[i]) - coil_img_d *= self._smap_d - self.__op(get_ptr(coil_img_d), get_ptr(coil_ksp_d)) - coil_ksp_d -= obs_data[i] + smaps_batched = self.smaps[idx_coils].reshape((T, *XYZ)) + data_batched *= smaps_batched + self.__op(get_ptr(data_batched), get_ptr(ksp_batched)) + + ksp_batched /= self.norm_factor + ksp_batched -= obs_batched + if self.uses_density: - coil_ksp_d *= self.density_d - self.__adj_op(get_ptr(coil_ksp_d), get_ptr(coil_img_d)) - if self.smaps_cached: - sense_adj_mono(img_d, coil_img_d, self.smaps[i]) + ksp_batched *= self.density + self.__adj_op(get_ptr(ksp_batched), get_ptr(data_batched)) + + for t, b in enumerate(idx_batch): + grad_d[b, :] += data_batched[t] * smaps_batched[t].conj() + grad_d /= self.norm_factor + grad = grad_d.get() + grad = grad.reshape((B, 1, *XYZ)) + return grad + + def _grad_sense_device(self, image_data, obs_data): + """Gradient computation when all data is on device.""" + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + + image_dataf = cp.reshape(image_data, (B, *XYZ)) + obs_dataf = cp.reshape(obs_data, (B * C, K)) + data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) + grad = cp.zeros((B, *XYZ), dtype=self.cpx_dtype) + + for i in range(B * C // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + data_batched.set(image_dataf[i * T : (i + 1) * T]) + if not self.smaps_cached: + smaps_batched.set(self.smaps[idx_coils].reshape((T, *XYZ))) else: - sense_adj_mono(img_d, coil_img_d, self._smap_d) - return img_d + smaps_batched = self.smaps[idx_coils].reshape((T, *XYZ)) + data_batched *= smaps_batched + self.__op(get_ptr(data_batched), get_ptr(ksp_batched)) + ksp_batched /= self.norm_factor + ksp_batched -= obs_dataf[i * T : (i + 1) * T] + + if self.uses_density: + ksp_batched *= self.density + self.__adj_op(get_ptr(ksp_batched), get_ptr(data_batched)) + + for t, b in enumerate(idx_batch): + # TODO write a kernel for that. 
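+                # SENSE reduction: accumulate conj(smap) * coil image into the
+                # gradient of batch b, using plain CuPy arithmetic for now.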
+ grad[b] += data_batched[t] * smaps_batched[t].conj() + grad = grad.reshape((B, 1, *XYZ)) + grad /= self.norm_factor + return grad + + def _grad_calibless_host(self, image_data, obs_data): + """Calibrationless Gradient computation when all data is on host.""" + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + + image_dataf = np.reshape(image_data, (B * C, *XYZ)) + obs_dataf = np.reshape(obs_data, (B * C, K)) - def _data_consistency_calibless(self, image_data, obs_data): - if is_cuda_array(image_data): - img_d = cp.empty((self.n_coils, *self.shape), dtype=self.cpx_dtype) - ksp_d = cp.empty(self.n_samples, dtype=self.cpx_dtype) - for i in range(self.n_coils): - self.__op(get_ptr(image_data) + i * self.img_size, get_ptr(ksp_d)) - ksp_d /= self.norm_factor - ksp_d -= obs_data[i] - if self.uses_density: - ksp_d *= self.density_d - self.__adj_op(get_ptr(ksp_d), get_ptr(img_d) + i * self.img_size) - return img_d / self.norm_factor - - img_d = cp.empty(self.shape, dtype=self.cpx_dtype) - img = np.zeros((self.n_coils, *self.shape), dtype=self.cpx_dtype) - ksp_d = cp.empty(self.n_samples, dtype=self.cpx_dtype) - obs_d = cp.empty(self.n_samples, dtype=self.cpx_dtype) - for i in range(self.n_coils): - img_d.set(image_data[i]) - obs_d.set(obs_data[i]) - self.__op(get_ptr(img_d), get_ptr(ksp_d)) - ksp_d /= self.norm_factor - ksp_d -= obs_d + data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + + ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) + obs_batched = cp.empty((T, K), dtype=self.cpx_dtype) + + grad = np.empty((B * C, *XYZ), dtype=self.cpx_dtype) + + for i in range(B * C // T): + data_batched.set(image_dataf[i * T : (i + 1) * T]) + obs_batched.set(obs_dataf[i * T : (i + 1) * T]) + self.__op(get_ptr(data_batched), get_ptr(ksp_batched)) + ksp_batched /= self.norm_factor + ksp_batched -= obs_batched + if self.uses_density: + ksp_batched *= self.density + self.__adj_op(get_ptr(ksp_batched), get_ptr(data_batched)) + data_batched /= self.norm_factor + grad[i * T : (i + 1) * T] = data_batched.get() + grad = grad.reshape((B, C, *XYZ)) + return grad + + def _grad_calibless_device(self, image_data, obs_data): + """Calibrationless Gradient computation when all data is on device.""" + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape + + data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) + + grad = cp.empty((B * C, *XYZ), dtype=self.cpx_dtype) + + for i in range(B * C // T): + data_batched.set(image_data[i * T : (i + 1) * T]) + self.__op(get_ptr(data_batched), get_ptr(ksp_batched)) + ksp_batched /= self.norm_factor + ksp_batched -= obs_data[i * T : (i + 1) * T] if self.uses_density: - ksp_d *= self.density_d - self.__adj_op(get_ptr(ksp_d), get_ptr(img_d)) - cp.asnumpy(img_d, out=img[i]) - return img / self.norm_factor + ksp_batched *= self.density + self.__adj_op(get_ptr(ksp_batched), get_ptr(data_batched)) + grad[i * T : (i + 1) * T] = data_batched + grad = grad.reshape((B, C, *XYZ)) + grad /= self.norm_factor + return grad def _safe_squeeze(self, arr): """Squeeze the first two dimensions of shape of the operator.""" diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index cac7d1ee..659e5dae 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -1,9 +1,8 @@ """Finufft interface.""" import numpy as np -import warnings -from ..base import 
FourierOperatorBase, proper_trajectory +from ..base import FourierOperatorCPU, proper_trajectory FINUFFT_AVAILABLE = True try: @@ -24,7 +23,7 @@ def __init__( **kwargs, ): self.shape = shape - self.samples = samples + self.samples = proper_trajectory(np.asfortranarray(samples), normalize="pi") self.ndim = len(shape) self.eps = float(eps) self.n_trans = n_trans @@ -53,16 +52,22 @@ def _set_pts(self, typ): fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype) self.plans[typ].setpts(*fpts_axes) - def adj_op(self, coeff_data, grid_data): + def adj_op(self, coeffs_data, grid_data): """Type 1 transform. Non Uniform to Uniform.""" - return self.plans[1].execute(coeff_data, grid_data) + if self.n_trans == 1: + grid_data = grid_data.reshape(self.shape) + coeffs_data = coeffs_data.reshape(len(self.samples)) + return self.plans[1].execute(coeffs_data, grid_data) - def op(self, coeff_data, grid_data): + def op(self, coeffs_data, grid_data): """Type 2 transform. Uniform to non-uniform.""" - return self.plans[2].execute(grid_data, coeff_data) + if self.n_trans == 1: + grid_data = grid_data.reshape(self.shape) + coeffs_data = coeffs_data.reshape(len(self.samples)) + return self.plans[2].execute(grid_data, coeffs_data) -class MRIfinufft(FourierOperatorBase): +class MRIfinufft(FourierOperatorCPU): """MRI Transform Operator using finufft. Parameters @@ -103,240 +108,23 @@ def __init__( n_batchs=1, n_trans=1, smaps=None, - squeeze_dims=False, + squeeze_dims=True, + **kwargs, ): - super().__init__() - - self.shape = shape - - # we will access the samples by their coordinate first. - self.samples = proper_trajectory(np.asfortranarray(samples), normalize="pi") - self.dtype = self.samples.dtype + super().__init__( + samples, + shape, + density, + n_coils=n_coils, + n_batchs=n_batchs, + n_trans=n_trans, + smaps=smaps, + squeeze_dims=squeeze_dims, + ) - # Density Compensation Setup - if density is True: - self.density = self.estimate_density(self.samples, shape) - elif isinstance(density, np.ndarray): - if len(density) != len(self.samples): - raise ValueError( - "Density array and samples array should have the same length." - ) - self.density = np.asfortranarray(density) - else: - self.density = None - # Multi Coil Setup - if n_coils < 1: - raise ValueError("n_coils should be ≥ 1") - self.n_coils = n_coils - self.smaps = smaps - self.n_batchs = n_batchs - self.n_trans = n_trans - self.squeeze_dims = squeeze_dims - # Initialise NUFFT plans self.raw_op = RawFinufftPlan( - self.samples, - tuple(shape), + samples, + shape, n_trans=n_trans, + **kwargs, ) - - def op(self, data, ksp=None): - r"""Non Cartesian MRI forward operator. - - Parameters - ---------- - data: np.ndarray - The uniform (2D or 3D) data in image space. - - Returns - ------- - Results array on the same device as data. - - Notes - ----- - this performs for every coil \ell: - ..math:: \mathcal{F}\mathcal{S}_\ell x - """ - if data.dtype != self.cpx_dtype: - warnings.warn( - f"Data should be of dtype {self.cpx_dtype} (is {data.dtype}). " - "Casting it for you." - ) - data = data.astype(self.cpx_dtype) - # sense - if self.uses_sense: - ret = self._op_sense(data, ksp) - # calibrationless or monocoil. 
- else: - ret = self._op_calibless(data, ksp) - ret /= self.norm_factor - return self._safe_squeeze(ret) - - def _op_sense(self, data, ksp=None): - T, B, C = self.n_trans, self.n_batchs, self.n_coils - K, XYZ = self.n_samples, self.shape - dataf = data.reshape((B, *XYZ)) - if ksp is None: - ksp = np.empty((B * C, K), dtype=self.cpx_dtype) - for i in range(B * C // T): - idx_coils = np.arange(i * T, (i + 1) * T) % C - idx_batch = np.arange(i * T, (i + 1) * T) // C - coil_img = self.smaps[idx_coils].copy().reshape((T, *XYZ)) - coil_img *= dataf[idx_batch] - self._op(coil_img, ksp[i * T : (i + 1) * T]) - ksp = ksp.reshape((B, C, K)) - return ksp - - def _op_calibless(self, data, ksp=None): - if ksp is None: - ksp = np.empty( - (self.n_batchs * self.n_coils, self.n_samples), dtype=self.cpx_dtype - ) - dataf = np.reshape(data, (self.n_batchs * self.n_coils, *self.shape)) - if self.n_trans == 1: - for i in range(self.n_coils * self.n_batchs): - self._op(dataf[i], ksp[i]) - else: - for i in range((self.n_coils * self.n_batchs) // self.n_trans): - self._op( - dataf[i * self.n_trans : (i + 1) * self.n_trans], - ksp[i * self.n_trans : (i + 1) * self.n_trans], - ) - ksp = ksp.reshape((self.n_batchs, self.n_coils, self.n_samples)) - return ksp - - def _op(self, image, coeffs): - if self.n_trans == 1: - image = image.reshape(self.shape) - coeffs = coeffs.reshape(self.n_samples) - self.raw_op.op(coeffs, image) - - def adj_op(self, coeffs, img=None): - """Non Cartesian MRI adjoint operator. - - Parameters - ---------- - coeffs: np.array or GPUArray - - Returns - ------- - Array in the same memory space of coeffs. (ie on cpu or gpu Memory). - """ - if coeffs.dtype != self.cpx_dtype: - warnings.warn( - f"coeffs should be of dtype {self.cpx_dtype}. Casting it for you." - ) - coeffs = coeffs.astype(self.cpx_dtype) - if self.uses_sense: - ret = self._adj_op_sense(coeffs, img) - # calibrationless or monocoil. 
- else: - ret = self._adj_op_calibless(coeffs, img) - ret /= self.norm_factor - return self._safe_squeeze(ret) - - def _adj_op_sense(self, coeffs, img=None): - T, B, C = self.n_trans, self.n_batchs, self.n_coils - K, XYZ = self.n_samples, self.shape - if img is None: - img = np.zeros((B, *XYZ), dtype=self.cpx_dtype) - coeffs_flat = coeffs.reshape((B * C, K)) - img_batched = np.zeros((T, *XYZ), dtype=self.cpx_dtype) - for i in range(B * C // T): - idx_coils = np.arange(i * T, (i + 1) * T) % C - idx_batch = np.arange(i * T, (i + 1) * T) // C - self._adj_op(coeffs_flat[i * T : (i + 1) * T], img_batched) - img_batched *= self.smaps[idx_coils].conj() - for t, b in enumerate(idx_batch): - img[b] += img_batched[t] - img = img.reshape((B, 1, *XYZ)) - return img - - def _adj_op_calibless(self, coeffs, img=None): - T, B, C = self.n_trans, self.n_batchs, self.n_coils - K, XYZ = self.n_samples, self.shape - if img is None: - img = np.empty((B * C, *XYZ), dtype=self.cpx_dtype) - coeffs_f = np.reshape(coeffs, (B * C, K)) - for i in range((B * C) // T): - self._adj_op(coeffs_f[i * T : (i + 1) * T], img[i * T : (i + 1) * T]) - - img = img.reshape((B, C, *XYZ)) - return img - - def _adj_op(self, coeffs, image): - if self.density is not None: - coeffs2 = coeffs.copy() - for i in range(self.n_trans): - coeffs2[i * self.n_samples : (i + 1) * self.n_samples] *= self.density - else: - coeffs2 = coeffs - if self.n_trans == 1: - image = image.reshape(self.shape) - coeffs2 = coeffs2.reshape(self.n_samples) - self.raw_op.adj_op(coeffs2, image) - - @property - def norm_factor(self): - """Norm factor of the operator.""" - return np.sqrt(np.prod(self.shape) * (2 ** len(self.shape))) - - def data_consistency(self, image_data, obs_data): - """Compute the gradient estimation directly on gpu. - - This mixes the op and adj_op method to perform F_adj(F(x-y)) - on a per coil basis. By doing the computation coil wise, - it uses less memory than the naive call to adj_op(op(x)-y) - - Parameters - ---------- - image: array - Image on which the gradient operation will be evaluated. - N_coil x Image shape is not using sense. - obs_data: array - Observed data. 
- """ - if self.uses_sense: - return self._data_consistency_sense(image_data, obs_data) - return self._data_consistency_calibless(image_data, obs_data) - - def _data_consistency_sense(self, image_data, obs_data): - img = np.empty_like(image_data) - coil_img = np.empty(self.shape, dtype=image_data.dtype) - coil_ksp = np.empty(self.n_samples, dtype=obs_data.dtype) - for i in range(self.n_coils): - np.copyto(coil_img, img) - coil_img *= self._smap - self._op(coil_img, coil_ksp) - coil_ksp /= self.norm_factor - coil_ksp -= obs_data[i] - if self.uses_density: - coil_ksp *= self.density_d - self._adj_op(coil_ksp, coil_img) - coil_img /= self.norm_factor - img += coil_img * self._smaps[i].conjugate() - return img - - def _data_consistency_calibless(self, image_data, obs_data): - img = np.empty((self.n_coils, *self.shape), dtype=image_data.dtype) - ksp = np.empty(self.n_samples, dtype=obs_data.dtype) - for i in range(self.n_coils): - self._op(image_data[i], ksp) - ksp /= self.norm_factor - ksp -= obs_data[i] - if self.uses_density: - ksp *= self.density_d - self._adj_op(ksp, img[i]) - return img / self.norm_factor - - def _safe_squeeze(self, arr): - """Squeeze the first two dimensions of shape of the operator.""" - if self.squeeze_dims: - try: - arr = arr.squeeze(axis=1) - except ValueError: - pass - try: - arr = arr.squeeze(axis=0) - except ValueError: - pass - return arr diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index a081364d..f1cee6ab 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -10,7 +10,7 @@ GPUNUFFT_AVAILABLE = False -class gpuNUFFT: +class RawGpuNUFFT: """GPU implementation of N-D non-uniform fast Fourier Transform class. Attributes @@ -36,8 +36,10 @@ def __init__( kernel_width=3, sector_width=8, osf=2, + upsampfac=None, balance_workload=True, smaps=None, + pinned_smaps=None, ): """Initialize the 'NUFFT' class. @@ -60,10 +62,19 @@ def __init__( sector width to use osf: int default 2 oversampling factor (usually between 1 and 2) + upsampfac: int default 2 + Same as osf. balance_workload: bool default True whether the workloads need to be balanced smaps: np.ndarray default None Holds the sensitivity maps for SENSE reconstruction + pinned_smaps: np.ndarray default None + Pinned memory array for the smaps. + + Notes + ----- + pinned_smaps status (pinned or not) is not checked here, but in the C++ code. + If its not pinned, then an extra copy will be triggered. 
""" if GPUNUFFT_AVAILABLE is False: raise ValueError( @@ -76,18 +87,26 @@ def __init__( self.samples = proper_trajectory(samples, normalize="unit") if density_comp is None: density_comp = np.ones(samples.shape[0]) - if smaps is None: - self.uses_sense = False + + self.uses_sense = True + if smaps is not None and pinned_smaps is None: + # no pinning provided, we will pin it in the C++ code + pinned_smaps = smaps.T.reshape(-1, n_coils) + elif smaps is not None and pinned_smaps is not None: + # Pinned memory space exists, we will overwrite it + np.copyto(pinned_smaps, smaps.T.reshape(-1, n_coils)) else: - smaps = np.asarray( - [np.reshape(smap_ch.T, smap_ch.size) for smap_ch in smaps] - ).T - self.uses_sense = True + # No smaps provided, we will not use SENSE + self.uses_sense = False + + if upsampfac is not None: + osf = upsampfac + self.operator = NUFFTOp( np.reshape(samples, samples.shape[::-1], order="F"), shape, n_coils, - smaps, + pinned_smaps, density_comp, kernel_width, sector_width, @@ -174,6 +193,8 @@ class MRIGpuNUFFT(FourierOperatorBase): if True, the density compensation is estimated from the samples locations. If an array is passed, it is used as the density compensation. + squeeze_dims: bool default True + This has no effect, gpuNUFFT always squeeze the data. smaps: np.ndarray default None Holds the sensitivity maps for SENSE reconstruction. kwargs: extra keyword args @@ -184,7 +205,17 @@ class MRIGpuNUFFT(FourierOperatorBase): backend = "gpunufft" available = GPUNUFFT_AVAILABLE - def __init__(self, samples, shape, n_coils=1, density=None, smaps=None, **kwargs): + def __init__( + self, + samples, + shape, + n_coils=1, + density=None, + smaps=None, + squeeze_dims=False, + eps=1e-3, + **kwargs, + ): if GPUNUFFT_AVAILABLE is False: raise ValueError( "gpuNUFFT library is not installed, " @@ -203,12 +234,13 @@ def __init__(self, samples, shape, n_coils=1, density=None, smaps=None, **kwargs else: self.density = None self.kwargs = kwargs - self.impl = gpuNUFFT( + self.impl = RawGpuNUFFT( samples=self.samples, shape=self.shape, n_coils=self.n_coils, density_comp=self.density, smaps=smaps, + kernel_width=kwargs.get("kernel_width", -int(np.log10(eps))), **self.kwargs, ) @@ -242,10 +274,6 @@ def adj_op(self, coeffs, *args): """ return self.impl.adj_op(coeffs, *args) - def data_consistency(self, data, obs_data): - """Compute the data consistency gradient direction.""" - return self.adj_op(self.op(data) - obs_data) - @property def uses_sense(self): """Return True if the Fourier Operator uses the SENSE method.""" diff --git a/src/mrinufft/operators/interfaces/nfft.py b/src/mrinufft/operators/interfaces/nfft.py index 1b433e46..1e8427ad 100644 --- a/src/mrinufft/operators/interfaces/nfft.py +++ b/src/mrinufft/operators/interfaces/nfft.py @@ -58,11 +58,13 @@ class MRInfft(FourierOperatorCPU): backend = "pynfft" available = PYNFFT_AVAILABLE - def __init__(self, samples, shape, n_coils=1, smaps=None): + def __init__(self, samples, shape, n_coils=1, n_batchs=1, smaps=None): super().__init__( samples, shape, n_coils=n_coils, + n_batchs=n_batchs, + n_trans=1, smaps=smaps, raw_op=None, # is set later, after normalizing samples. ) diff --git a/src/mrinufft/operators/interfaces/sigpy.py b/src/mrinufft/operators/interfaces/sigpy.py new file mode 100644 index 00000000..d836667b --- /dev/null +++ b/src/mrinufft/operators/interfaces/sigpy.py @@ -0,0 +1,131 @@ +"""Sigpy NUFFT interface. + +The SigPy NUFFT is fully implemented in Python. 
+""" +import warnings + +import numpy as np +from ..base import FourierOperatorCPU, proper_trajectory + + +SIGPY_AVAILABLE = True +try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import sigpy.fourier as sgf +except ImportError: + SIGPY_AVAILABLE = False + + +class RawSigpyNUFFT: + """Raw interface to SigPy output /= width**ndim NUFFT. + + Parameters + ---------- + samples: np.ndarray + the kspace sample locations in the Fourier domain, + normalized between -0.5 and 0.5 + shape: tuple of int + shape of the image + oversamp: float default 1.25 + oversampling factor + width: int default 4 + interpolation kernel width (usually 3 to 7) + upsampfac: float, default 1.25 + Same as oversamp + eps: float, default 1e-4 + Other way of specifiying width. + """ + + def __init__( + self, + samples, + shape, + oversamp=1.25, + width=4, + eps=None, + upsampfac=None, + n_trans=1, + **kwargs, + ): + if upsampfac is not None: + oversamp = upsampfac + if eps is not None: + width = -int(np.log10(eps)) + + self.shape = shape + shape = np.array(shape) + # scale in FOV/2 units + self.samples = samples * shape + self.n_trans = n_trans + self._oversamp = oversamp + self._width = width + + def op(self, coeffs_data, grid_data): + """Forward Operator.""" + grid_data_ = grid_data.reshape(self.n_trans, *self.shape) + ret = sgf.nufft( + grid_data_, + self.samples, + oversamp=self._oversamp, + width=self._width, + ) + ret = ret.reshape(self.n_trans, len(self.samples)) + np.copyto(coeffs_data, ret) + return coeffs_data + + def adj_op(self, coeffs_data, grid_data): + """Adjoint Operator.""" + coeffs_data_ = coeffs_data.reshape(self.n_trans, len(self.samples)) + ret = sgf.nufft_adjoint( + coeffs_data_, + self.samples, + oshape=(self.n_trans, *self.shape), + oversamp=self._oversamp, + width=self._width, + ) + ret = ret.reshape(self.n_trans, *self.shape) + np.copyto(grid_data, ret) + return grid_data + + +class MRISigpyNUFFT(FourierOperatorCPU): + """NUFFT using SigPy. + + This is a wrapper around the SigPy NUFFT operator. 
+ """ + + backend = "sigpy" + available = SIGPY_AVAILABLE + + def __init__( + self, + samples, + shape, + density=False, + n_coils=1, + n_batchs=1, + n_trans=1, + smaps=None, + squeeze_dims=True, + **kwargs, + ): + samples_ = proper_trajectory(samples, normalize="unit") + + super().__init__( + samples_, + shape, + density=density, + n_coils=n_coils, + n_batchs=n_batchs, + n_trans=n_trans, + smaps=smaps, + squeeze_dims=squeeze_dims, + ) + + self.raw_op = RawSigpyNUFFT(samples_, shape, n_trans=n_trans, **kwargs) + + @property + def norm_factor(self): + """Normalization factor of the operator.""" + return np.sqrt(2 ** len(self.shape)) diff --git a/tests/case_trajectories.py b/tests/case_trajectories.py index f1c45cef..0f1e0397 100644 --- a/tests/case_trajectories.py +++ b/tests/case_trajectories.py @@ -23,9 +23,9 @@ def case_random2D(self, M=1000, N=64, pdf="uniform", seed=0): def case_random3D(self, M=200000, N=64, pdf="uniform", seed=0): """Create a random 3D trajectory.""" np.random.seed(seed) - samples = np.random.randn(M, 3) - samples /= samples.max() - samples -= 0.5 + samples = sp.stats.truncnorm(-3, 3, loc=0, scale=0.16).rvs(size=M * 3) + samples = samples.reshape(M, 3) + print(samples.min(), samples.max()) return samples, (N, N, N) def case_radial2D(self, Nc=10, Ns=500, N=64): diff --git a/tests/conftest.py b/tests/conftest.py index d8ca3b1c..473b54d5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,9 +21,7 @@ def pytest_addoption(parser): def pytest_configure(config): """Configure hook for pytest.""" - print("Available backends:") - for backend in list_backends(): - print(f"{backend:<14}: {FourierOperatorBase.interfaces[backend][0]}") + available = {b: FourierOperatorBase.interfaces[b][0] for b in list_backends()} if selected := config.getoption("backend"): # hijacks the availability of interfaces: @@ -38,9 +36,13 @@ def pytest_configure(config): True, FourierOperatorBase.interfaces[ref_backend][1], ) - print("Selected backends:") - for backend in list_backends(): - print(f"{backend:<14}: {FourierOperatorBase.interfaces[backend][0]}") + selected = {b: FourierOperatorBase.interfaces[b][0] for b in list_backends()} + + available[ref_backend] = "REF" + selected[ref_backend] = "REF" + print(f"{'backends':>20}: {'avail':>5} {'select':<5}") + for b in list_backends(): + print(f"{b:>20}: {str(available[b]):>5} {str(selected[b]):>5}") # # for test directly parametrized by a backend diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index 957faca4..04c64b8d 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -1,4 +1,11 @@ -#!/usr/bin/env python3 +"""Helper functions for testing the operators.""" -from .asserts import assert_almost_allclose +from .asserts import assert_almost_allclose, assert_correlate from .factories import kspace_from_op, image_from_op + +__all__ = [ + "assert_almost_allclose", + "assert_correlate", + "kspace_from_op", + "image_from_op", +] diff --git a/tests/helpers/asserts.py b/tests/helpers/asserts.py index cb6362b6..5d4e932b 100644 --- a/tests/helpers/asserts.py +++ b/tests/helpers/asserts.py @@ -3,6 +3,7 @@ import numpy as np import numpy.testing as npt +import scipy as sp def assert_almost_allclose(a, b, rtol, atol, mismatch, equal_nan=False): @@ -43,3 +44,17 @@ def assert_almost_allclose(a, b, rtol, atol, mismatch, equal_nan=False): e.message += "\nMismatched elements: " e.message += f"{np.sum(~val)} > {mismatch}(={mismatch_perc*100:.2f}%)" raise e + + +def assert_correlate(a, b, slope=1.0, slope_err=1e-3, 
r_value_err=1e-3):
+    """Assert that two arrays are linearly correlated, with the expected slope."""
+    reg = sp.stats.linregress(a.flatten(), b.flatten())
+    abs_slope_reg = abs(reg.slope)
+    if abs(reg.rvalue - 1) > r_value_err:
+        raise AssertionError(f"r-value {reg.rvalue} != 1 +- {r_value_err}")
+    if abs(abs_slope_reg - slope) > slope_err:
+        raise AssertionError(
+            f"Slope {abs_slope_reg} != {slope} +- {slope_err}\n r={reg.rvalue}, "
+            f"intercept={reg.intercept}, stderr={reg.stderr}, "
+            f"intercept_stderr={reg.intercept_stderr}"
+        )
diff --git a/tests/test_batch.py b/tests/test_batch.py
index 93840f2b..29c18bb1 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -4,8 +4,9 @@
 """
 import numpy as np
 import numpy.testing as npt
+import pytest
 from pytest_cases import parametrize_with_cases, parametrize, fixture
-
+from helpers import assert_correlate
 from mrinufft import get_operator
 from case_trajectories import CasesTrajectories

@@ -44,9 +45,11 @@ def operator(
     if sense:
         smaps = 1j * np.random.rand(n_coils, *shape)
         smaps += np.random.rand(n_coils, *shape)
+        smaps = smaps.astype(np.complex64)
         smaps /= np.linalg.norm(smaps, axis=0)
     else:
         smaps = None
+    kspace_locs = kspace_locs.astype(np.float32)
     return get_operator(backend)(
         kspace_locs,
         shape,
@@ -54,6 +57,7 @@
         smaps=smaps,
         n_batchs=n_batch,
         n_trans=n_trans,
+        squeeze_dims=False,
     )


@@ -86,7 +90,7 @@ def kspace_data(operator):
     return kspace


-def test_batch_type2(operator, flat_operator, image_data):
+def test_batch_op(operator, flat_operator, image_data):
     """Test the batch type 2 (forward)."""
     kspace_data = operator.op(image_data)

@@ -106,7 +110,7 @@
     npt.assert_array_almost_equal_nulp(kspace_data, kspace_flat)


-def test_batch_type1(operator, flat_operator, kspace_data):
+def test_batch_adj_op(operator, flat_operator, kspace_data):
     """Test the batch type 1 (adjoint)."""
     kspace_flat = kspace_data.reshape(-1, operator.n_coils, operator.n_samples)
     image_flat = [None] * operator.n_batchs
@@ -126,3 +130,50 @@
     image_data = operator.adj_op(kspace_data)
     # Reduced accuracy for the GPU cases...
     npt.assert_allclose(image_data, image_flat, atol=1e-3, rtol=1e-3)
+
+
+def test_data_consistency(operator, image_data, kspace_data):
+    """Test the data consistency operation."""
+    # image_data = np.zeros_like(image_data)
+    res = operator.data_consistency(image_data, kspace_data)
+    tmp = operator.op(image_data)
+    res2 = operator.adj_op(tmp - kspace_data)
+
+    # npt.assert_allclose(res.squeeze(), res2.squeeze(), atol=1e-4, rtol=1e-1)
+    res = res.reshape(-1, *operator.shape)
+    res2 = res2.reshape(-1, *operator.shape)
+
+    slope_err = 1e-3
+    # FIXME 2D Sense is not very accurate... 
+    if len(operator.shape) == 2 and operator.uses_sense:
+        slope_err = 1e-1
+
+    for i in range(len(res)):
+        assert_correlate(res[i], res2[i], slope_err=slope_err)
+
+
+def test_data_consistency_readonly(operator, image_data, kspace_data):
+    """Test that data consistency does not modify its input arrays."""
+    kspace_tmp = kspace_data.copy()
+    image_tmp = image_data.copy()
+    kspace_tmp.setflags(write=False)
+    image_tmp.setflags(write=False)
+    operator.data_consistency(image_data, kspace_tmp)
+    npt.assert_equal(kspace_tmp, kspace_data)
+    npt.assert_equal(image_tmp, image_data)
+
+
+def test_gradient_lipschitz(operator, image_data, kspace_data):
+    """Test the gradient Lipschitz constant."""
+    C = 1 if operator.uses_sense else operator.n_coils
+    img = image_data.copy().reshape(operator.n_batchs, C, *operator.shape)
+    norm = np.inf
+    for _ in range(10):
+        grad = operator.data_consistency(img, kspace_data)
+        norm_prev, norm = norm, np.linalg.norm(grad)  # keep the previous norm
+        grad /= norm
+        np.copyto(img, grad)
+
+    # TODO: check that the value is "not too far" from 1
+    # TODO: to do the same with density compensation
+    assert abs(norm - norm_prev) / norm_prev < 1e-3
diff --git a/tests/test_density.py b/tests/test_density.py
deleted file mode 100644
index 7ca92e0c..00000000
--- a/tests/test_density.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""Test the density compensation estimations."""
-import numpy.testing as npt
-from pytest_cases import parametrize_with_cases, parametrize
-
-from mrinufft.trajectories.density import pipe
-from case_trajectories import CasesTrajectories
-
-
-@parametrize_with_cases(
-    "kspace_traj, shape",
-    cases=[CasesTrajectories.case_radial3D],
-)
-@parametrize("backend", ["cufinufft", "tensorflow"])
-def test_density_pipe(kspace_traj, shape, backend):
-    """Test the density compensation estimations."""
-    density = pipe(kspace_traj, shape, backend=backend, num_iter=20, tol=2e-7).get()
-    density_ref = pipe(kspace_traj, shape, backend=backend, num_iter=25, tol=2e-7).get()
-
-    # TODO: get tighter bounds. 
- npt.assert_allclose(density, density_ref, atol=1, rtol=1) diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 3e91fd14..6a5cdb11 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -20,10 +20,12 @@ @parametrize( "backend", [ + "bart", "pynfft", "finufft", "cufinufft", "gpunufft", + "sigpy", ], ) @parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories) @@ -40,7 +42,7 @@ def operator( @fixture(scope="session", autouse=True) def ref_backend(request): - """get the reference backend from the CLI""" + """Get the reference backend from the CLI.""" return request.config.getoption("ref") @@ -94,39 +96,4 @@ def test_interfaces_autoadjoint(operator): leftadjoint = np.vdot(img_data, image) reldiff[i] = abs(rightadjoint - leftadjoint) / abs(leftadjoint) print(reldiff) - assert np.mean(reldiff) < 1e-5 - - -def test_data_consistency_readonly(operator, image_data, kspace_data): - """Test that the data consistency does not modify the input parameters data.""" - kspace_tmp = kspace_data.copy() - image_tmp = image_data.copy() - kspace_tmp.setflags(write=False) - image_tmp.setflags(write=False) - operator.data_consistency(image_data, kspace_tmp) - npt.assert_equal(kspace_tmp, kspace_data) - npt.assert_equal(image_tmp, image_data) - - -def test_data_consistency(operator, image_data, kspace_data): - """Test the data consistency operation.""" - res = operator.data_consistency(image_data, kspace_data) - - res2 = operator.adj_op(operator.op(image_data) - kspace_data) - - npt.assert_allclose(res.squeeze(), res2.squeeze(), atol=1e-4, rtol=1e-1) - - -def test_gradient_lipschitz(operator, image_data, kspace_data): - """Test the gradient lipschitz constant.""" - img = image_data.copy() - for _ in range(10): - grad = operator.data_consistency(img, kspace_data) - norm = np.linalg.norm(grad) - grad /= norm - np.copyto(img, grad) - norm_prev = norm - - # TODO: check that the value is "not too far" from 1 - # TODO: to do the same with density compensation - assert (norm - norm_prev) / norm_prev < 1e-3 + assert np.mean(reldiff) < 5e-5 diff --git a/tests/test_stacked.py b/tests/test_stacked.py index bd6f46f6..ccc49c84 100644 --- a/tests/test_stacked.py +++ b/tests/test_stacked.py @@ -8,6 +8,7 @@ import numpy.testing as npt from pytest_cases import parametrize_with_cases, parametrize, fixture +from helpers import assert_correlate from mrinufft.operators.stacked import MRIStackedNUFFT, stacked2traj3d, traj3d2stacked from mrinufft import get_operator from case_trajectories import CasesTrajectories @@ -27,12 +28,10 @@ def operator(request, backend, kspace_locs, shape, z_index, n_batchs, n_coils, s if z_index == "full": z_index = np.arange(shape3d[-1]) - z_index_ = z_index elif z_index == "random_mask": - z_index = np.random.rand(shape3d[-1]) > 0.5 - z_index_ = np.arange(shape3d[-1])[z_index] + z_index = np.random.choice(shape3d[-1], shape3d[-1] // 2, replace=False) - kspace_locs3d = stacked2traj3d(kspace_locs, z_index_, shape[-1]) + kspace_locs3d = stacked2traj3d(kspace_locs, z_index, shape[-1]) # smaps support if sense: smaps = 1j * np.random.rand(n_coils, *shape3d) @@ -101,7 +100,7 @@ def test_stack_forward(operator, stacked_op, ref_op, image_data): """Compare the stack interface to the 3D NUFFT implementation.""" kspace_nufft = stacked_op.op(image_data).squeeze() kspace_ref = ref_op.op(image_data).squeeze() - npt.assert_allclose(kspace_nufft, kspace_ref, atol=1e-4, rtol=1e-1) + assert_correlate(kspace_nufft, kspace_ref) def test_stack_backward(operator, stacked_op, ref_op, 
kspace_data): @@ -109,7 +108,7 @@ def test_stack_backward(operator, stacked_op, ref_op, kspace_data): image_nufft = stacked_op.adj_op(kspace_data.copy()).squeeze() image_ref = ref_op.adj_op(kspace_data.copy()).squeeze() - npt.assert_allclose(image_nufft, image_ref, atol=1e-4, rtol=1e-1) + assert_correlate(image_nufft, image_ref) def test_stack_auto_adjoint(operator, stacked_op, kspace_data, image_data):