Skip to content

Commit

Permalink
deprecate passing plugin classes to register_worker_plugin and regist…
Browse files Browse the repository at this point in the history
…er_scheduler_plugin

fail if unused kwargs are passed

Fixes #5698
  • Loading branch information
graingert committed Jan 26, 2022
1 parent 682a7b1 commit d15980c
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 13 deletions.
27 changes: 22 additions & 5 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4512,15 +4512,24 @@ def register_scheduler_plugin(self, plugin, name=None, **kwargs):
Parameters
----------
plugin : SchedulerPlugin
Plugin class or object to pass to the scheduler.
SchedulerPlugin instance to pass to the scheduler.
name : str
Name for the plugin; if None, a name is taken from the
plugin instance or automatically generated if not present.
**kwargs : Any
Arguments passed to the Plugin class (if Plugin is an
deprecated; Arguments passed to the Plugin class (if Plugin is an
instance kwargs are unused).
"""
if isinstance(plugin, type):
warnings.warn(
"Adding plugins by class is deprecated and will be disabled in a "
"future release. Please add plugins by instance instead.",
category=FutureWarning,
)
elif kwargs:
raise ValueError("kwargs provided but plugin is already an instance")

if name is None:
name = _get_plugin_name(plugin)

Expand Down Expand Up @@ -4591,16 +4600,17 @@ def register_worker_plugin(self, plugin=None, name=None, nanny=None, **kwargs):
Parameters
----------
plugin : WorkerPlugin or NannyPlugin
The plugin object to register.
WorkerPlugin or NannyPlugin instance to register.
name : str, optional
A name for the plugin.
Registering a plugin with the same name will have no effect.
If plugin has no name attribute a random name is used.
nanny : bool, optional
Whether to register the plugin with workers or nannies.
**kwargs : optional
If you pass a class as the plugin, instead of a class instance, then the
class will be instantiated with any extra keyword arguments.
Deprecated; If you pass a class as the plugin, instead of a class
instance, then the class will be instantiated with any extra
keyword arguments.
Examples
--------
Expand Down Expand Up @@ -4636,7 +4646,14 @@ class will be instantiated with any extra keyword arguments.
unregister_worker_plugin
"""
if isinstance(plugin, type):
warnings.warn(
"Adding plugins by class is deprecated and will be disabled in a "
"future release. Please add plugins by instance instead.",
category=FutureWarning,
)
plugin = plugin(**kwargs)
elif kwargs:
raise ValueError("kwargs provided but plugin is already an instance")

if name is None:
name = _get_plugin_name(plugin)
Expand Down
28 changes: 24 additions & 4 deletions distributed/diagnostics/tests/test_scheduler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ def start(self, scheduler):
scheduler.foo = "bar"

assert not hasattr(s, "foo")
await c.register_scheduler_plugin(Dummy1)
await c.register_scheduler_plugin(Dummy1())
assert s.foo == "bar"

with pytest.warns(UserWarning) as w:
await c.register_scheduler_plugin(Dummy1)
await c.register_scheduler_plugin(Dummy1())
assert "Scheduler already contains" in w[0].message.args[0]

class Dummy2(SchedulerPlugin):
Expand All @@ -185,7 +185,7 @@ def start(self, scheduler):

n_plugins = len(s.plugins)
with pytest.raises(RuntimeError, match="raising in start method"):
await c.register_scheduler_plugin(Dummy2)
await c.register_scheduler_plugin(Dummy2())
# total number of plugins should be unchanged
assert n_plugins == len(s.plugins)

Expand All @@ -198,10 +198,30 @@ def start(self, scheduler):

n_plugins = len(s.plugins)
with pytest.raises(ValueError) as excinfo:
await c.register_scheduler_plugin(Dummy1)
await c.register_scheduler_plugin(Dummy1())

msg = str(excinfo.value)
assert "disallowed from deserializing" in msg
assert "distributed.scheduler.pickle" in msg

assert n_plugins == len(s.plugins)


@gen_cluster(nthreads=[], client=True)
async def test_plugin_class_warns(c, s):
class EmptyPlugin(SchedulerPlugin):
pass

with pytest.warns(FutureWarning, match=r"Adding plugins by class is deprecated"):
await c.register_scheduler_plugin(EmptyPlugin)


@gen_cluster(nthreads=[], client=True)
async def test_unused_kwargs_throws(c, s):
class EmptyPlugin(SchedulerPlugin):
pass

with pytest.raises(
ValueError, match=r"kwargs provided but plugin is already an instance"
):
await c.register_scheduler_plugin(EmptyPlugin(), data=789)
30 changes: 26 additions & 4 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,39 @@ async def test_remove_with_client_raises(c, s):

@gen_cluster(client=True, nthreads=[])
async def test_create_with_client_and_plugin_from_class(c, s):
await c.register_worker_plugin(MyPlugin, data=456)
with pytest.warns(FutureWarning, match=r"Adding plugins by class is deprecated"):
await c.register_worker_plugin(MyPlugin, data=456)

worker = await Worker(s.address, loop=s.loop)
assert worker._my_plugin_status == "setup"
assert worker._my_plugin_data == 456

# Give the plugin a new name so that it registers
await c.register_worker_plugin(MyPlugin, name="new", data=789)
with pytest.warns(FutureWarning, match=r"Adding plugins by class is deprecated"):
await c.register_worker_plugin(MyPlugin, data=789, name="new")
assert worker._my_plugin_data == 789


@gen_cluster(nthreads=[], client=True)
async def test_plugin_class_warns(c, s):
class EmptyPlugin:
pass

with pytest.warns(FutureWarning, match=r"Adding plugins by class is deprecated"):
await c.register_worker_plugin(EmptyPlugin)


@gen_cluster(nthreads=[], client=True)
async def test_unused_kwargs_throws(c, s):
class EmptyPlugin:
pass

with pytest.raises(
ValueError, match=r"kwargs provided but plugin is already an instance"
):
await c.register_worker_plugin(EmptyPlugin(), data=789)


@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
async def test_create_on_construction(c, s, a, b):
assert len(a.plugins) == len(b.plugins) == 1
Expand Down Expand Up @@ -264,7 +286,7 @@ def transition(self, *args, **kwargs):
def teardown(self, worker):
del self.worker.foo

await c.register_worker_plugin(MyCustomPlugin)
await c.register_worker_plugin(MyCustomPlugin())

assert w.foo == 0

Expand All @@ -287,7 +309,7 @@ def transition(self, *args, **kwargs):
def teardown(self, worker):
del self.worker.bar

await c.register_worker_plugin(MyCustomPlugin)
await c.register_worker_plugin(MyCustomPlugin())

assert not hasattr(w, "foo")
assert w.bar == 0
Expand Down

0 comments on commit d15980c

Please sign in to comment.