diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py
index 624815e75..d1024ff69 100644
--- a/dask_cuda/tests/test_explicit_comms.py
+++ b/dask_cuda/tests/test_explicit_comms.py
@@ -11,7 +11,7 @@
 from dask import dataframe as dd
 from dask.dataframe.shuffle import partitioning_index
 from dask.dataframe.utils import assert_eq
-from distributed import Client, get_worker
+from distributed import Client
 from distributed.deploy.local import LocalCluster
 
 import dask_cuda
@@ -314,8 +314,8 @@ def test_jit_unspill(protocol):
 
 
 def _test_lock_workers(scheduler_address, ranks):
-    async def f(_):
-        worker = get_worker()
+    async def f(info):
+        worker = info["worker"]
         if hasattr(worker, "running"):
             assert not worker.running
         worker.running = True
diff --git a/dask_cuda/tests/test_local_cuda_cluster.py b/dask_cuda/tests/test_local_cuda_cluster.py
index a72ec3f2e..f2e48783c 100644
--- a/dask_cuda/tests/test_local_cuda_cluster.py
+++ b/dask_cuda/tests/test_local_cuda_cluster.py
@@ -9,7 +9,6 @@
 from dask.distributed import Client
 from distributed.system import MEMORY_LIMIT
 from distributed.utils_test import gen_test, raises_with_cause
-from distributed.worker import get_worker
 
 from dask_cuda import CUDAWorker, LocalCUDACluster, utils
 from dask_cuda.initialize import initialize
@@ -140,7 +139,9 @@ async def test_no_memory_limits_cluster():
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:
             # Check that all workers use a regular dict as their "data store".
-            res = await client.run(lambda: isinstance(get_worker().data, dict))
+            res = await client.run(
+                lambda dask_worker: isinstance(dask_worker.data, dict)
+            )
             assert all(res.values())
 
 
@@ -161,7 +162,9 @@ async def test_no_memory_limits_cudaworker():
             await new_worker
             await client.wait_for_workers(2)
             # Check that all workers use a regular dict as their "data store".
-            res = await client.run(lambda: isinstance(get_worker().data, dict))
+            res = await client.run(
+                lambda dask_worker: isinstance(dask_worker.data, dict)
+            )
             assert all(res.values())
             await new_worker.close()
 
diff --git a/dask_cuda/tests/test_proxify_host_file.py b/dask_cuda/tests/test_proxify_host_file.py
index 41399d673..50b2c51a5 100644
--- a/dask_cuda/tests/test_proxify_host_file.py
+++ b/dask_cuda/tests/test_proxify_host_file.py
@@ -12,7 +12,6 @@
 from dask.utils import format_bytes
 from distributed import Client
 from distributed.utils_test import gen_test
-from distributed.worker import get_worker
 
 import dask_cuda
 import dask_cuda.proxify_device_objects
@@ -429,9 +428,9 @@ async def test_worker_force_spill_to_disk():
                 ddf = dask.dataframe.from_pandas(df, npartitions=1).persist()
                 await ddf
 
-                async def f():
+                async def f(dask_worker):
                     """Trigger a memory_monitor() and reset memory_limit"""
-                    w = get_worker()
+                    w = dask_worker
                     # Set a host memory limit that triggers spilling to disk
                     w.memory_manager.memory_pause_fraction = False
                     memory = w.monitor.proc.memory_info().rss
@@ -443,7 +442,7 @@ async def f():
                     assert w.monitor.proc.memory_info().rss < memory - 10**7
                     w.memory_manager.memory_limit = memory * 10  # Un-limit
 
-                await client.submit(f)
+                await client.run(f)
                 log = str(await client.get_worker_logs())
                 # Check that the worker doesn't complain about unmanaged memory
                 assert "Unmanaged memory use is high" not in log
diff --git a/dask_cuda/tests/test_spill.py b/dask_cuda/tests/test_spill.py
index f93b83ec7..bbd24d5ad 100644
--- a/dask_cuda/tests/test_spill.py
+++ b/dask_cuda/tests/test_spill.py
@@ -6,7 +6,7 @@
 
 import dask
 from dask import array as da
-from distributed import Client, get_worker, wait
+from distributed import Client, wait
 from distributed.metrics import time
 from distributed.sizeof import sizeof
 from distributed.utils_test import gen_cluster, gen_test, loop  # noqa: F401
@@ -57,21 +57,25 @@ def assert_device_host_file_size(
     )
 
 
-def worker_assert(total_size, device_chunk_overhead, serialized_chunk_overhead):
+def worker_assert(
+    dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
+):
     assert_device_host_file_size(
-        get_worker().data, total_size, device_chunk_overhead, serialized_chunk_overhead
+        dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
     )
 
 
-def delayed_worker_assert(total_size, device_chunk_overhead, serialized_chunk_overhead):
+def delayed_worker_assert(
+    dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
+):
     start = time()
     while not device_host_file_size_matches(
-        get_worker().data, total_size, device_chunk_overhead, serialized_chunk_overhead
+        dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
     ):
         sleep(0.01)
         if time() < start + 3:
             assert_device_host_file_size(
-                get_worker().data,
+                dask_worker.data,
                 total_size,
                 device_chunk_overhead,
                 serialized_chunk_overhead,
             )
@@ -143,17 +147,23 @@ async def test_cupy_cluster_device_spill(params):
             await wait(xx)
 
             # Allow up to 1024 bytes overhead per chunk serialized
-            await client.run(worker_assert, x.nbytes, 1024, 1024)
+            await client.run(
+                lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024)
+            )
 
             y = client.compute(x.sum())
             res = await y
 
             assert (abs(res / x.size) - 0.5) < 1e-3
 
-            await client.run(worker_assert, x.nbytes, 1024, 1024)
-            host_chunks = await client.run(lambda: len(get_worker().data.host))
+            await client.run(
+                lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024)
+            )
+            host_chunks = await client.run(
+                lambda dask_worker: len(dask_worker.data.host)
+            )
             disk_chunks = await client.run(
-                lambda: len(get_worker().data.disk or list())
+                lambda dask_worker: len(dask_worker.data.disk or list())
             )
             for hc, dc in zip(host_chunks.values(), disk_chunks.values()):
                 if params["spills_to_disk"]:
@@ -245,9 +255,11 @@ async def test_cudf_cluster_device_spill(params):
 
             del cdf
 
-            host_chunks = await client.run(lambda: len(get_worker().data.host))
+            host_chunks = await client.run(
+                lambda dask_worker: len(dask_worker.data.host)
+            )
             disk_chunks = await client.run(
-                lambda: len(get_worker().data.disk or list())
+                lambda dask_worker: len(dask_worker.data.disk or list())
            )
             for hc, dc in zip(host_chunks.values(), disk_chunks.values()):
                 if params["spills_to_disk"]:
@@ -256,8 +268,12 @@
                     assert hc > 0
                     assert dc == 0
 
-            await client.run(worker_assert, nbytes, 32, 2048)
+            await client.run(
+                lambda dask_worker: worker_assert(dask_worker, nbytes, 32, 2048)
+            )
 
             del cdf2
 
-            await client.run(delayed_worker_assert, 0, 0, 0)
+            await client.run(
+                lambda dask_worker: delayed_worker_assert(dask_worker, 0, 0, 0)
+            )
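Every hunk above applies the same pattern: instead of calling `distributed.get_worker()` inside a function passed to `Client.run` (newer `distributed` releases restrict `get_worker()` to code running inside a task), the function declares a parameter named `dask_worker`, and `Client.run` injects the local worker instance for it. Below is a minimal sketch of that pattern on a plain `LocalCluster`; the helper name `inspect_data_store` and the cluster parameters are illustrative only and not part of this diff.

```python
from distributed import Client, LocalCluster


def inspect_data_store(dask_worker):
    # `dask_worker` is the Worker instance executing this call, injected by
    # Client.run() because of the parameter name; no get_worker() needed.
    return type(dask_worker.data).__name__


if __name__ == "__main__":
    # Small local cluster for illustration.
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            # Client.run() invokes the function once on every worker and returns
            # a dict mapping worker address -> return value.
            print(client.run(inspect_data_store))
            # A lambda taking `dask_worker` works the same way, which is the form
            # used by the tests in this diff.
            print(client.run(lambda dask_worker: len(dask_worker.data)))
```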