diff --git a/distributed/client.py b/distributed/client.py index 6ab04f9b5ce..ce1dfa463de 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4531,15 +4531,25 @@ def register_scheduler_plugin(self, plugin, name=None, **kwargs): Parameters ---------- plugin : SchedulerPlugin - Plugin class or object to pass to the scheduler. + SchedulerPlugin instance to pass to the scheduler. name : str Name for the plugin; if None, a name is taken from the plugin instance or automatically generated if not present. **kwargs : Any - Arguments passed to the Plugin class (if Plugin is an + deprecated; Arguments passed to the Plugin class (if Plugin is an instance kwargs are unused). """ + if isinstance(plugin, type): + warnings.warn( + "Adding plugins by class is deprecated and will be disabled in a " + "future release. Please add plugins by instance instead.", + category=FutureWarning, + ) + # note: plugin is constructed in async def _register_scheduler_plugin + elif kwargs: + raise ValueError("kwargs provided but plugin is already an instance") + if name is None: name = _get_plugin_name(plugin) @@ -4610,7 +4620,7 @@ def register_worker_plugin(self, plugin=None, name=None, nanny=None, **kwargs): Parameters ---------- plugin : WorkerPlugin or NannyPlugin - The plugin object to register. + WorkerPlugin or NannyPlugin instance to register. name : str, optional A name for the plugin. Registering a plugin with the same name will have no effect. @@ -4618,8 +4628,9 @@ def register_worker_plugin(self, plugin=None, name=None, nanny=None, **kwargs): nanny : bool, optional Whether to register the plugin with workers or nannies. **kwargs : optional - If you pass a class as the plugin, instead of a class instance, then the - class will be instantiated with any extra keyword arguments. + Deprecated; If you pass a class as the plugin, instead of a class + instance, then the class will be instantiated with any extra + keyword arguments. Examples -------- @@ -4655,7 +4666,14 @@ class will be instantiated with any extra keyword arguments. unregister_worker_plugin """ if isinstance(plugin, type): + warnings.warn( + "Adding plugins by class is deprecated and will be disabled in a " + "future release. Please add plugins by instance instead.", + category=FutureWarning, + ) plugin = plugin(**kwargs) + elif kwargs: + raise ValueError("kwargs provided but plugin is already an instance") if name is None: name = _get_plugin_name(plugin) diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 4678765d8d1..f8054af310e 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -170,11 +170,11 @@ def start(self, scheduler): scheduler.foo = "bar" assert not hasattr(s, "foo") - await c.register_scheduler_plugin(Dummy1) + await c.register_scheduler_plugin(Dummy1()) assert s.foo == "bar" with pytest.warns(UserWarning) as w: - await c.register_scheduler_plugin(Dummy1) + await c.register_scheduler_plugin(Dummy1()) assert "Scheduler already contains" in w[0].message.args[0] class Dummy2(SchedulerPlugin): @@ -185,7 +185,7 @@ def start(self, scheduler): n_plugins = len(s.plugins) with pytest.raises(RuntimeError, match="raising in start method"): - await c.register_scheduler_plugin(Dummy2) + await c.register_scheduler_plugin(Dummy2()) # total number of plugins should be unchanged assert n_plugins == len(s.plugins) @@ -198,10 +198,30 @@ def start(self, scheduler): n_plugins = len(s.plugins) with pytest.raises(ValueError) as excinfo: - await c.register_scheduler_plugin(Dummy1) + await c.register_scheduler_plugin(Dummy1()) msg = str(excinfo.value) assert "disallowed from deserializing" in msg assert "distributed.scheduler.pickle" in msg assert n_plugins == len(s.plugins) + + +@gen_cluster(nthreads=[], client=True) +async def test_plugin_class_warns(c, s): + class EmptyPlugin(SchedulerPlugin): + pass + + with pytest.warns(FutureWarning, match=r"Adding plugins by class is deprecated"): + await c.register_scheduler_plugin(EmptyPlugin) + + +@gen_cluster(nthreads=[], client=True) +async def test_unused_kwargs_throws(c, s): + class EmptyPlugin(SchedulerPlugin): + pass + + with pytest.raises( + ValueError, match=r"kwargs provided but plugin is already an instance" + ): + await c.register_scheduler_plugin(EmptyPlugin(), data=789) diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index c6dfd1d561d..f740c8d3fcc 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -85,17 +85,39 @@ async def test_remove_with_client_raises(c, s): @gen_cluster(client=True, nthreads=[]) async def test_create_with_client_and_plugin_from_class(c, s): - await c.register_worker_plugin(MyPlugin, data=456) + with pytest.warns(FutureWarning, match=r"Adding plugins by class is deprecated"): + await c.register_worker_plugin(MyPlugin, data=456) worker = await Worker(s.address, loop=s.loop) assert worker._my_plugin_status == "setup" assert worker._my_plugin_data == 456 # Give the plugin a new name so that it registers - await c.register_worker_plugin(MyPlugin, name="new", data=789) + with pytest.warns(FutureWarning, match=r"Adding plugins by class is deprecated"): + await c.register_worker_plugin(MyPlugin, data=789, name="new") assert worker._my_plugin_data == 789 +@gen_cluster(nthreads=[], client=True) +async def test_plugin_class_warns(c, s): + class EmptyPlugin: + pass + + with pytest.warns(FutureWarning, match=r"Adding plugins by class is deprecated"): + await c.register_worker_plugin(EmptyPlugin) + + +@gen_cluster(nthreads=[], client=True) +async def test_unused_kwargs_throws(c, s): + class EmptyPlugin: + pass + + with pytest.raises( + ValueError, match=r"kwargs provided but plugin is already an instance" + ): + await c.register_worker_plugin(EmptyPlugin(), data=789) + + @gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]}) async def test_create_on_construction(c, s, a, b): assert len(a.plugins) == len(b.plugins) == 1 @@ -264,7 +286,7 @@ def transition(self, *args, **kwargs): def teardown(self, worker): del self.worker.foo - await c.register_worker_plugin(MyCustomPlugin) + await c.register_worker_plugin(MyCustomPlugin()) assert w.foo == 0 @@ -287,7 +309,7 @@ def transition(self, *args, **kwargs): def teardown(self, worker): del self.worker.bar - await c.register_worker_plugin(MyCustomPlugin) + await c.register_worker_plugin(MyCustomPlugin()) assert not hasattr(w, "foo") assert w.bar == 0