diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 32921e3d..7689441e 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -101,9 +101,11 @@ class or instance of class if args or kwargs are given. try: available, operator = FourierOperatorBase.interfaces[backend_name] except KeyError as exc: - raise ValueError("backend is not available") from exc + raise ValueError(f"backend {backend_name} is not available") from exc if not available: - raise ValueError("backend is registered, but dependencies are not met.") + raise ValueError( + f"backend {backend_name} is registered, but dependencies are not met." + ) if args or kwargs: operator = operator(*args, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 0a585363..d8ca3b1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,10 +12,15 @@ def pytest_addoption(parser): default=[], help="NUFFT backend on which the tests are performed.", ) + parser.addoption( + "--ref", + default="pynfft", + help="Reference backend on which the tests are performed.", + ) def pytest_configure(config): - """Configuration hook for pytest.""" + """Configure hook for pytest.""" print("Available backends:") for backend in list_backends(): print(f"{backend:<14}: {FourierOperatorBase.interfaces[backend][0]}") @@ -27,6 +32,15 @@ def pytest_configure(config): backend in selected, FourierOperatorBase.interfaces[backend][1], ) + # ensure the ref backend is available + ref_backend = config.getoption("ref") + FourierOperatorBase.interfaces[ref_backend] = ( + True, + FourierOperatorBase.interfaces[ref_backend][1], + ) + print("Selected backends:") + for backend in list_backends(): + print(f"{backend:<14}: {FourierOperatorBase.interfaces[backend][0]}") # # for test directly parametrized by a backend @@ -58,8 +72,13 @@ def pytest_generate_tests(metafunc): if v.argnames[0] == "backend" ][0] ] + print("backend detected", backend) # Only keep the callspec if the backend is available. if not check_backend(backend): callspec.marks.append( pytest.mark.skip(f"Backend {backend} not available.") ) + if backend == metafunc.config.getoption("ref"): + callspec.marks.append( + pytest.mark.skip("Not testing ref backend with self.") + ) diff --git a/tests/test_cpu.py b/tests/test_cpu.py deleted file mode 100644 index 8de82610..00000000 --- a/tests/test_cpu.py +++ /dev/null @@ -1 +0,0 @@ -"""Test for the CPU interfaces.""" diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index d98bc8f7..3e91fd14 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -38,11 +38,16 @@ def operator( return get_operator(backend)(kspace_locs, shape, n_coils=n_coils, smaps=None) +@fixture(scope="session", autouse=True) +def ref_backend(request): + """get the reference backend from the CLI""" + return request.config.getoption("ref") + + @fixture(scope="module") -@parametrize("backend", ["pynfft"]) -def nfft_ref_op(request, operator, backend="pynfft"): +def ref_operator(request, operator, ref_backend): """Generate a NFFT operator, matching the property of the first operator.""" - return get_operator(backend)( + return get_operator(ref_backend)( operator.samples, operator.shape, n_coils=operator.n_coils, smaps=operator.smaps ) @@ -59,18 +64,18 @@ def kspace_data(operator): return kspace_from_op(operator) -def test_interfaces_accuracy_forward(operator, image_data, nfft_ref_op): +def test_interfaces_accuracy_forward(operator, image_data, ref_operator): """Compare the interface to the raw NUDFT implementation.""" kspace_nufft = operator.op(image_data).squeeze() - kspace_ref = nfft_ref_op.op(image_data).squeeze() # FIXME: check with complex values ail + kspace_ref = ref_operator.op(image_data).squeeze() assert np.percentile(abs(kspace_nufft - kspace_ref) / abs(kspace_ref), 95) < 1e-1 -def test_interfaces_accuracy_backward(operator, kspace_data, nfft_ref_op): +def test_interfaces_accuracy_backward(operator, kspace_data, ref_operator): """Compare the interface to the raw NUDFT implementation.""" image_nufft = operator.adj_op(kspace_data.copy()).squeeze() - image_ref = nfft_ref_op.adj_op(kspace_data.copy()).squeeze() + image_ref = ref_operator.adj_op(kspace_data.copy()).squeeze() assert np.percentile(abs(image_nufft - image_ref) / abs(image_ref), 95) < 1e-1 diff --git a/tests/test_stacked_gpu.py b/tests/test_stacked_gpu.py index 6d57c437..0051a19a 100644 --- a/tests/test_stacked_gpu.py +++ b/tests/test_stacked_gpu.py @@ -112,16 +112,6 @@ def test_stack_backward(operator, stacked_op, ref_op, kspace_data): """Compare the stack interface to the 3D NUFFT implementation.""" image_nufft = stacked_op.adj_op(kspace_data.copy()).squeeze() image_ref = ref_op.adj_op(kspace_data.copy()).squeeze() - if stacked_op.n_coils > 1: - print( - np.max( - np.abs(image_nufft - image_ref).reshape( - stacked_op.n_batchs, stacked_op.n_coils, -1 - ), - axis=-1, - ) - ) - print(image_nufft.shape, image_ref.shape) npt.assert_allclose(image_nufft, image_ref, atol=1e-4, rtol=1e-1)