diff --git a/dask_cuda/tests/test_dask_cuda_worker.py b/dask_cuda/tests/test_dask_cuda_worker.py index bfcf91193..6f653a365 100644 --- a/dask_cuda/tests/test_dask_cuda_worker.py +++ b/dask_cuda/tests/test_dask_cuda_worker.py @@ -217,7 +217,6 @@ def test_pre_import(loop): # noqa: F811 assert all(imported) -@pytest.mark.xfail(reason="https://github.com/dask/distributed/issues/6320") @pytest.mark.timeout(20) @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) def test_pre_import_not_found(): diff --git a/dask_cuda/tests/test_local_cuda_cluster.py b/dask_cuda/tests/test_local_cuda_cluster.py index 4a027ad1d..fc835281a 100644 --- a/dask_cuda/tests/test_local_cuda_cluster.py +++ b/dask_cuda/tests/test_local_cuda_cluster.py @@ -7,7 +7,7 @@ from dask.distributed import Client from distributed.system import MEMORY_LIMIT -from distributed.utils_test import gen_test +from distributed.utils_test import gen_test, raises_with_cause from dask_cuda import CUDAWorker, LocalCUDACluster, utils from dask_cuda.initialize import initialize @@ -243,7 +243,7 @@ async def test_pre_import(): # Intentionally not using @gen_test to skip cleanup checks async def test_pre_import_not_found(): - with pytest.raises(ModuleNotFoundError): + with raises_with_cause(RuntimeError, None, ImportError, None): await LocalCUDACluster( n_workers=1, pre_import="my_module",