Skip to content

Commit

Permalink
feat(test): cleanup test
Browse files Browse the repository at this point in the history
* remove prints

* feat: better error message.

* feat(test): use --backend and --ref wisely.

* delete empty test file.
  • Loading branch information
paquiteau authored Oct 25, 2023
1 parent 11b1658 commit 156a42b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 21 deletions.
6 changes: 4 additions & 2 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ class or instance of class if args or kwargs are given.
try:
available, operator = FourierOperatorBase.interfaces[backend_name]
except KeyError as exc:
raise ValueError("backend is not available") from exc
raise ValueError(f"backend {backend_name} is not available") from exc
if not available:
raise ValueError("backend is registered, but dependencies are not met.")
raise ValueError(
f"backend {backend_name} is registered, but dependencies are not met."
)

if args or kwargs:
operator = operator(*args, **kwargs)
Expand Down
21 changes: 20 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ def pytest_addoption(parser):
default=[],
help="NUFFT backend on which the tests are performed.",
)
parser.addoption(
"--ref",
default="pynfft",
help="Reference backend on which the tests are performed.",
)


def pytest_configure(config):
"""Configuration hook for pytest."""
"""Configure hook for pytest."""
print("Available backends:")
for backend in list_backends():
print(f"{backend:<14}: {FourierOperatorBase.interfaces[backend][0]}")
Expand All @@ -27,6 +32,15 @@ def pytest_configure(config):
backend in selected,
FourierOperatorBase.interfaces[backend][1],
)
# ensure the ref backend is available
ref_backend = config.getoption("ref")
FourierOperatorBase.interfaces[ref_backend] = (
True,
FourierOperatorBase.interfaces[ref_backend][1],
)
print("Selected backends:")
for backend in list_backends():
print(f"{backend:<14}: {FourierOperatorBase.interfaces[backend][0]}")


# # for test directly parametrized by a backend
Expand Down Expand Up @@ -58,8 +72,13 @@ def pytest_generate_tests(metafunc):
if v.argnames[0] == "backend"
][0]
]
print("backend detected", backend)
# Only keep the callspec if the backend is available.
if not check_backend(backend):
callspec.marks.append(
pytest.mark.skip(f"Backend {backend} not available.")
)
if backend == metafunc.config.getoption("ref"):
callspec.marks.append(
pytest.mark.skip("Not testing ref backend with self.")
)
1 change: 0 additions & 1 deletion tests/test_cpu.py

This file was deleted.

19 changes: 12 additions & 7 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@ def operator(
return get_operator(backend)(kspace_locs, shape, n_coils=n_coils, smaps=None)


@fixture(scope="session", autouse=True)
def ref_backend(request):
"""get the reference backend from the CLI"""
return request.config.getoption("ref")


@fixture(scope="module")
@parametrize("backend", ["pynfft"])
def nfft_ref_op(request, operator, backend="pynfft"):
def ref_operator(request, operator, ref_backend):
"""Generate a NFFT operator, matching the property of the first operator."""
return get_operator(backend)(
return get_operator(ref_backend)(
operator.samples, operator.shape, n_coils=operator.n_coils, smaps=operator.smaps
)

Expand All @@ -59,18 +64,18 @@ def kspace_data(operator):
return kspace_from_op(operator)


def test_interfaces_accuracy_forward(operator, image_data, nfft_ref_op):
def test_interfaces_accuracy_forward(operator, image_data, ref_operator):
"""Compare the interface to the raw NUDFT implementation."""
kspace_nufft = operator.op(image_data).squeeze()
kspace_ref = nfft_ref_op.op(image_data).squeeze()
# FIXME: check with complex values ail
kspace_ref = ref_operator.op(image_data).squeeze()
assert np.percentile(abs(kspace_nufft - kspace_ref) / abs(kspace_ref), 95) < 1e-1


def test_interfaces_accuracy_backward(operator, kspace_data, nfft_ref_op):
def test_interfaces_accuracy_backward(operator, kspace_data, ref_operator):
"""Compare the interface to the raw NUDFT implementation."""
image_nufft = operator.adj_op(kspace_data.copy()).squeeze()
image_ref = nfft_ref_op.adj_op(kspace_data.copy()).squeeze()
image_ref = ref_operator.adj_op(kspace_data.copy()).squeeze()

assert np.percentile(abs(image_nufft - image_ref) / abs(image_ref), 95) < 1e-1

Expand Down
10 changes: 0 additions & 10 deletions tests/test_stacked_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,6 @@ def test_stack_backward(operator, stacked_op, ref_op, kspace_data):
"""Compare the stack interface to the 3D NUFFT implementation."""
image_nufft = stacked_op.adj_op(kspace_data.copy()).squeeze()
image_ref = ref_op.adj_op(kspace_data.copy()).squeeze()
if stacked_op.n_coils > 1:
print(
np.max(
np.abs(image_nufft - image_ref).reshape(
stacked_op.n_batchs, stacked_op.n_coils, -1
),
axis=-1,
)
)
print(image_nufft.shape, image_ref.shape)
npt.assert_allclose(image_nufft, image_ref, atol=1e-4, rtol=1e-1)


Expand Down

0 comments on commit 156a42b

Please sign in to comment.