
Commit

Merge remote-tracking branch 'refs/remotes/upstream/main' into shuffle-p2p
phofl committed Aug 7, 2024
2 parents f727033 + 92fc0e2 commit 50ea288
Showing 12 changed files with 679 additions and 272 deletions.
7 changes: 4 additions & 3 deletions distributed/client.py
@@ -3362,7 +3362,9 @@ def _graph_to_futures(
        warnings.warn(
            f"Sending large graph of size {format_bytes(pickled_size)}.\n"
            "This may cause some slowdown.\n"
-            "Consider scattering data ahead of time and using futures."
+            "Consider loading the data with Dask directly\n or using futures or "
+            "delayed objects to embed the data into the graph without repetition.\n"
+            "See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information."
        )

        computations = self._get_computation_code(
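The reworded warning points at the Dask best-practices docs. As a rough sketch of the two approaches it now recommends (the array shapes and the use of dask.array below are illustrative, not part of this commit):

    import numpy as np

    import dask.array as da
    from dask import delayed

    big = np.random.random((1000, 1000))  # stand-in for a large local dataset

    # Embedding `big` directly pickles it into the graph sent to the scheduler,
    # which is the situation the warning above is about.
    x = da.from_array(big, chunks=(250, 250))

    # Letting Dask produce the data, or wrapping the producer in `delayed`,
    # ships only the recipe instead of the bytes.
    lazy = delayed(np.random.random)((1000, 1000))
    y = da.from_delayed(lazy, shape=(1000, 1000), dtype=float)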
@@ -5409,8 +5411,7 @@ async def _unregister_worker_plugin(self, name, nanny=None):

        for response in responses.values():
            if response["status"] == "error":
-                exc = response["exception"]
-                tb = response["traceback"]
+                _, exc, tb = clean_exception(**response)
                raise exc.with_traceback(tb)
        return responses
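
For context, clean_exception (from distributed.core) unpacks an error response into (type, exception, traceback) and deserializes the fields if they arrived in pickled form, which is why it replaces the raw dict lookups above. A minimal sketch of the call pattern, with a hand-built response for illustration:

    from distributed.core import clean_exception

    # In real traffic the exception/traceback may arrive as pickled bytes;
    # clean_exception decodes them back into live objects.
    response = {"status": "error", "exception": RuntimeError("boom"), "traceback": None}

    typ, exc, tb = clean_exception(**response)  # extra keys such as "status" are ignored
    assert typ is RuntimeError
    try:
        raise exc.with_traceback(tb)
    except RuntimeError as err:
        print(f"re-raised client-side: {err}")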

59 changes: 58 additions & 1 deletion distributed/diagnostics/tests/test_nanny_plugin.py
@@ -1,10 +1,12 @@
from __future__ import annotations

+import logging
+
import pytest

from distributed import Nanny, NannyPlugin
from distributed.protocol.pickle import dumps
-from distributed.utils_test import gen_cluster
+from distributed.utils_test import captured_logger, gen_cluster


@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
@@ -160,3 +162,58 @@ def setup(self, nanny):
    await c.register_plugin(second, idempotent=True)
    assert "idempotentplugin" in a.plugins
    assert a.plugins["idempotentplugin"].instance == "first"
+
+
+class BrokenSetupPlugin(NannyPlugin):
+    def setup(self, nanny):
+        raise RuntimeError("test error")
+
+
+@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
+async def test_register_plugin_with_broken_setup_to_existing_nannies_raises(c, s, a):
+    with pytest.raises(RuntimeError, match="test error"):
+        with captured_logger("distributed.nanny", level=logging.ERROR) as caplog:
+            await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to setup" in logs
+    assert "test error" in logs
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_plugin_with_broken_setup_on_new_nanny_logs(c, s):
+    await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
+
+    with captured_logger("distributed.nanny", level=logging.ERROR) as caplog:
+        async with Nanny(s.address):
+            pass
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to setup" in logs
+    assert "test error" in logs
+
+
+class BrokenTeardownPlugin(NannyPlugin):
+    def teardown(self, nanny):
+        raise RuntimeError("test error")
+
+
+@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
+async def test_unregister_nanny_plugin_with_broken_teardown_raises(c, s, a):
+    await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
+    with pytest.raises(RuntimeError, match="test error"):
+        with captured_logger("distributed.nanny", level=logging.ERROR) as caplog:
+            await c.unregister_worker_plugin("TestPlugin1", nanny=True)
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to teardown" in logs
+    assert "test error" in logs
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s):
+    await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
+
+    with captured_logger("distributed.nanny", level=logging.ERROR) as caplog:
+        async with Nanny(s.address):
+            pass
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to teardown" in logs
+    assert "test error" in logs
58 changes: 57 additions & 1 deletion distributed/diagnostics/tests/test_worker_plugin.py
@@ -1,13 +1,14 @@
from __future__ import annotations

import asyncio
+import logging
import warnings

import pytest

from distributed import Worker, WorkerPlugin
from distributed.protocol.pickle import dumps
-from distributed.utils_test import async_poll_for, gen_cluster, inc
+from distributed.utils_test import async_poll_for, captured_logger, gen_cluster, inc


class MyPlugin(WorkerPlugin):
@@ -423,3 +424,58 @@ def setup(self, worker):
    await c.register_plugin(second, idempotent=True)
    assert "idempotentplugin" in a.plugins
    assert a.plugins["idempotentplugin"].instance == "first"
+
+
+class BrokenSetupPlugin(WorkerPlugin):
+    def setup(self, worker):
+        raise RuntimeError("test error")
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_register_plugin_with_broken_setup_to_existing_workers_raises(c, s, a):
+    with pytest.raises(RuntimeError, match="test error"):
+        with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
+            await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to setup" in logs
+    assert "test error" in logs
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_plugin_with_broken_setup_on_new_worker_logs(c, s):
+    await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
+
+    with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
+        async with Worker(s.address):
+            pass
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to setup" in logs
+    assert "test error" in logs
+
+
+class BrokenTeardownPlugin(WorkerPlugin):
+    def teardown(self, worker):
+        raise RuntimeError("test error")
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_unregister_worker_plugin_with_broken_teardown_raises(c, s, a):
+    await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
+    with pytest.raises(RuntimeError, match="test error"):
+        with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
+            await c.unregister_worker_plugin("TestPlugin1")
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to teardown" in logs
+    assert "test error" in logs
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_plugin_with_broken_teardown_logs_on_close(c, s):
+    await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
+
+    with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
+        async with Worker(s.address):
+            pass
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to teardown" in logs
+    assert "test error" in logs
11 changes: 5 additions & 6 deletions distributed/distributed-schema.yaml
@@ -81,16 +81,15 @@ properties:
      - string
      - "null"
      description: |
-       Shut down the scheduler after this duration if there are pending tasks,
-       but no workers that can process them. This can either mean that there are
-       no workers running at all, or that there are idle workers but they've been
-       excluded through worker or resource restrictions.
+       Timeout for tasks in an unrunnable state.
+       If task remains unrunnable for longer than this, it fails. A task is considered unrunnable IFF
+       it has no pending dependencies, and the task has restrictions that are not satisfied by
+       any available worker or no workers are running at all.
+       In adaptive clusters, this timeout must be set to be safely higher than
+       the time it takes for workers to spin up.
        Works in conjunction with idle-timeout.
      work-stealing:
        type: boolean
        description: |
2 changes: 1 addition & 1 deletion distributed/distributed.yaml
@@ -19,7 +19,7 @@ distributed:
    # after they have been removed from the scheduler
    events-cleanup-delay: 1h
    idle-timeout: null # Shut down after this duration, like "1h" or "30 minutes"
-    no-workers-timeout: null # Shut down if there are tasks but no workers to process them
+    no-workers-timeout: null # If a task remains unrunnable for longer than this, it fails.
    work-stealing: True # workers should steal tasks from each other
    work-stealing-interval: 100ms # Callback time for work stealing
    worker-saturation: 1.1 # Send this fraction of nthreads root tasks to workers
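To try the renamed semantics on a live cluster, the option can be set like any other scheduler config value. A small sketch (the five-minute value is arbitrary):

    import dask

    # Fail tasks that stay unrunnable (e.g. unsatisfiable resource
    # restrictions, or no workers at all) for longer than five minutes.
    dask.config.set({"distributed.scheduler.no-workers-timeout": "5 minutes"})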
12 changes: 4 additions & 8 deletions distributed/nanny.py
@@ -477,13 +477,14 @@ async def plugin_add(

        self.plugins[name] = plugin

-        logger.info("Starting Nanny plugin %s" % name)
+        logger.info("Starting Nanny plugin %s", name)
        if hasattr(plugin, "setup"):
            try:
                result = plugin.setup(nanny=self)
                if isawaitable(result):
                    result = await result
            except Exception as e:
+                logger.exception("Nanny plugin %s failed to setup", name)
                return error_message(e)
        if getattr(plugin, "restart", False):
            await self.restart(reason=f"nanny-plugin-{name}-restart")
@@ -500,6 +501,7 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage:
                if isawaitable(result):
                    result = await result
        except Exception as e:
+            logger.exception("Nanny plugin %s failed to teardown", name)
            msg = error_message(e)
            return msg

@@ -610,13 +612,7 @@ async def close( # type:ignore[override]

        await self.preloads.teardown()

-        teardowns = [
-            plugin.teardown(self)
-            for plugin in self.plugins.values()
-            if hasattr(plugin, "teardown")
-        ]
-
-        await asyncio.gather(*(td for td in teardowns if isawaitable(td)))
+        await asyncio.gather(*(self.plugin_remove(name) for name in self.plugins))

        self.stop()
        if self.process is not None:
