From 97f9d60fd3889248069507c52f7deb51ddfab4f4 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Fri, 18 Jun 2021 21:27:00 -0600 Subject: [PATCH] Actor: don't hold key references on workers Fixes #4936 I don't think this is quite the right implementation. 1) Why does the `worker=` kwarg exist? It doesn't seem to be used. But it should be. Taking the `if worker` codepath would bypass this whole issue. 2) What if a user is using an Actor within a task? In that case, `get_worker` would return a Worker, but we _would_ want to hold a reference to the Actor key (as long as that task was running). I think a better implementation might be to include in `__reduce__` whether the Actor handle should be a weakref or not, basically. And in `Worker.get_data`, construct it such that it is a weakref. --- distributed/actor.py | 8 ++--- distributed/tests/test_actor.py | 57 ++++++++++++++++++++++++++++++++- distributed/worker.py | 2 +- 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index 77b2cda67de..231cc8b3a2d 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -3,11 +3,11 @@ import threading from queue import Queue -from .client import Future, default_client +from .client import Future from .protocol import to_serialize from .utils import iscoroutinefunction, sync, thread_state from .utils_comm import WrappedKey -from .worker import get_worker +from .worker import get_client, get_worker class Actor(WrappedKey): @@ -63,8 +63,8 @@ def __init__(self, cls, address, key, worker=None): except ValueError: self._worker = None try: - self._client = default_client() - self._future = Future(key) + self._client = get_client() + self._future = Future(key, inform=self._worker is None) except ValueError: self._client = None diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index 851ee7e8b2a..87c126bf2f9 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ 
-515,7 +515,7 @@ def check(dask_worker): start = time() while any(client.run(check).values()): sleep(0.01) - assert time() < start + 30 + assert time() < start + 10 @gen_cluster( @@ -566,6 +566,61 @@ async def wait(self): await c.gather(futures) +@gen_cluster(client=True, client_kwargs=dict(set_as_default=False)) +# ^ NOTE: without `set_as_default=False`, `get_client()` within worker would return +# the same client instance the test is using (because it's all one process). +# Even with this, both workers will share the same client instance. +async def test_worker_actor_handle_is_weakref(c, s, a, b): + counter = c.submit(Counter, actor=True, workers=[a.address]) + + await c.submit(lambda _: None, counter, workers=[b.address]) + + del counter + + start = time() + while a.actors or b.data: + await asyncio.sleep(0.1) + assert time() < start + 10 + + +def test_worker_actor_handle_is_weakref_sync(client): + workers = list(client.run(lambda: None)) + counter = client.submit(Counter, actor=True, workers=[workers[0]]) + + client.submit(lambda _: None, counter, workers=[workers[1]]).result() + + del counter + + def check(dask_worker): + return len(dask_worker.data) + len(dask_worker.actors) + + start = time() + while any(client.run(check).values()): + sleep(0.01) + assert time() < start + 10 + + +def test_worker_actor_handle_is_weakref_from_compute_sync(client): + workers = list(client.run(lambda: None)) + + with dask.annotate(workers=workers[0]): + counter = dask.delayed(Counter)() + with dask.annotate(workers=workers[1]): + intermediate = dask.delayed(lambda c: None)(counter) + with dask.annotate(workers=workers[0]): + final = dask.delayed(lambda x, c: x)(intermediate, counter) + + final.compute(actors=counter, optimize_graph=False) + + def worker_tasks_running(dask_worker): + return len(dask_worker.data) + len(dask_worker.actors) + + start = time() + while any(client.run(worker_tasks_running).values()): + sleep(0.01) + assert time() < start + 10 + + def 
test_one_thread_deadlock(): with cluster(nworkers=2) as (cl, w): client = Client(cl["address"]) diff --git a/distributed/worker.py b/distributed/worker.py index bd1bf6f0d75..44e05f05024 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1417,7 +1417,7 @@ async def get_data( if k in self.actors: from .actor import Actor - data[k] = Actor(type(self.actors[k]), self.address, k) + data[k] = Actor(type(self.actors[k]), self.address, k, worker=self) msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}} nbytes = {k: self.tasks[k].nbytes for k in data if k in self.tasks}