Skip to content

Commit

Permalink
Actor: don't hold key references on workers
Browse files Browse the repository at this point in the history
Fixes dask#4936

I don't think this is quite the right implementation.
1) Why does the `worker=` kwarg exist? It doesn't seem to be used. But it should be. Taking the `if worker` codepath would bypass this whole issue.
2) What if a user is using an Actor within a task? In that case, `get_worker` would return a Worker, but we _would_ want to hold a reference to the Actor key (as long as that task was running).

I think a better implementation might be to include in `__reduce__` whether or not the Actor handle should be a weakref or not, basically. And in `Worker.get_data`, construct it such that it is a weakref.
  • Loading branch information
gjoseph92 committed Jun 19, 2021
1 parent d9bc3c6 commit 97f9d60
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
8 changes: 4 additions & 4 deletions distributed/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import threading
from queue import Queue

from .client import Future, default_client
from .client import Future
from .protocol import to_serialize
from .utils import iscoroutinefunction, sync, thread_state
from .utils_comm import WrappedKey
from .worker import get_worker
from .worker import get_client, get_worker


class Actor(WrappedKey):
Expand Down Expand Up @@ -63,8 +63,8 @@ def __init__(self, cls, address, key, worker=None):
except ValueError:
self._worker = None
try:
self._client = default_client()
self._future = Future(key)
self._client = get_client()
self._future = Future(key, inform=self._worker is None)
except ValueError:
self._client = None

Expand Down
57 changes: 56 additions & 1 deletion distributed/tests/test_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def check(dask_worker):
start = time()
while any(client.run(check).values()):
sleep(0.01)
assert time() < start + 30
assert time() < start + 10


@gen_cluster(
Expand Down Expand Up @@ -566,6 +566,61 @@ async def wait(self):
await c.gather(futures)


@gen_cluster(client=True, client_kwargs=dict(set_as_default=False))
# ^ NOTE: without `set_as_default=False`, `get_client()` within worker would return
# the same client instance the test is using (because it's all one process).
# Even with this, both workers will share the same client instance.
async def test_worker_actor_handle_is_weakref(c, s, a, b):
counter = c.submit(Counter, actor=True, workers=[a.address])

await c.submit(lambda _: None, counter, workers=[b.address])

del counter

start = time()
while a.actors or b.data:
await asyncio.sleep(0.1)
assert time() < start + 10


def test_worker_actor_handle_is_weakref_sync(client):
workers = list(client.run(lambda: None))
counter = client.submit(Counter, actor=True, workers=[workers[0]])

client.submit(lambda _: None, counter, workers=[workers[1]]).result()

del counter

def check(dask_worker):
return len(dask_worker.data) + len(dask_worker.actors)

start = time()
while any(client.run(check).values()):
sleep(0.01)
assert time() < start + 10


def test_worker_actor_handle_is_weakref_from_compute_sync(client):
workers = list(client.run(lambda: None))

with dask.annotate(workers=workers[0]):
counter = dask.delayed(Counter)()
with dask.annotate(workers=workers[1]):
intermediate = dask.delayed(lambda c: None)(counter)
with dask.annotate(workers=workers[0]):
final = dask.delayed(lambda x, c: x)(intermediate, counter)

final.compute(actors=counter, optimize_graph=False)

def worker_tasks_running(dask_worker):
return len(dask_worker.data) + len(dask_worker.actors)

start = time()
while any(client.run(worker_tasks_running).values()):
sleep(0.01)
assert time() < start + 10


def test_one_thread_deadlock():
with cluster(nworkers=2) as (cl, w):
client = Client(cl["address"])
Expand Down
2 changes: 1 addition & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,7 +1417,7 @@ async def get_data(
if k in self.actors:
from .actor import Actor

data[k] = Actor(type(self.actors[k]), self.address, k)
data[k] = Actor(type(self.actors[k]), self.address, k, worker=self)

msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}}
nbytes = {k: self.tasks[k].nbytes for k in data if k in self.tasks}
Expand Down

0 comments on commit 97f9d60

Please sign in to comment.