From bc85bb0ffa0aa192dd09b338c967a5a623b3fca3 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 8 Feb 2021 11:16:09 -0600 Subject: [PATCH 1/3] Attempt to get client from worker in Queue and Variable --- distributed/queues.py | 12 ++++++++---- distributed/variable.py | 14 +++++++++----- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/distributed/queues.py b/distributed/queues.py index d022b010e1c..e368d329d03 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -7,7 +7,7 @@ from .client import Future, Client from .utils import sync, thread_state -from .worker import get_client +from .worker import get_client, get_worker from .utils import parse_timedelta logger = logging.getLogger(__name__) @@ -150,8 +150,8 @@ class Queue: Name used by other clients and the scheduler to identify the queue. If not given, a random name will be generated. client: Client (optional) - Client used for communication with the scheduler. Defaults to the - value of ``Client.current()``. + Client used for communication with the scheduler. + If not given, the default global client will be used. maxsize: int (optional) Number of items allowed in the queue. If 0 (the default), the queue size is unbounded. @@ -170,7 +170,11 @@ class Queue: """ def __init__(self, name=None, client=None, maxsize=0): - self.client = client or Client.current() + try: + self.client = client or Client.current() + except ValueError: + # Initialise new client + self.client = get_worker().client self.name = name or "queue-" + uuid.uuid4().hex self._event_started = asyncio.Event() if self.client.asynchronous or getattr( diff --git a/distributed/variable.py b/distributed/variable.py index 19fbd2bb031..c3fdc94d0d7 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -9,13 +9,13 @@ from dask.utils import stringify from .client import Future, Client from .utils import log_errors, TimeoutError, parse_timedelta -from .worker import get_client +from .worker import get_client, get_worker logger = logging.getLogger(__name__) class VariableExtension: - """An extension for the scheduler to manage queues + """An extension for the scheduler to manage Variables This adds the following routes to the scheduler @@ -145,8 +145,8 @@ class Variable: Name used by other clients and the scheduler to identify the variable. If not given, a random name will be generated. client: Client (optional) - Client used for communication with the scheduler. Defaults to the - value of ``Client.current()``. + Client used for communication with the scheduler. + If not given, the default global client will be used. Examples -------- @@ -165,7 +165,11 @@ class Variable: """ def __init__(self, name=None, client=None, maxsize=0): - self.client = client or Client.current() + try: + self.client = client or Client.current() + except ValueError: + # Initialise new client + self.client = get_worker().client self.name = name or "variable-" + uuid.uuid4().hex async def _set(self, value): From 8fedb7dfa2bbc694852a85611f9ca90ecd76cd60 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 8 Feb 2021 13:22:47 -0600 Subject: [PATCH 2/3] Add tests --- distributed/tests/test_queues.py | 18 ++++++++++++++++++ distributed/tests/test_variable.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 94d80c9dbcf..3431b5c6b29 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -5,6 +5,7 @@ import pytest from distributed import Client, Queue, Nanny, worker_client, wait, TimeoutError +from distributed.client import _del_global_client from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, div from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -276,3 +277,20 @@ def get(): res = c.submit(get) await c.gather([res, fut]) + + +@gen_cluster(client=True) +async def test_queue_in_task(c, s, a, b): + x = await Queue("x") + await x.put(123) + + def foo(): + y = Queue("x") + return y.get() + + # We want to make sure Client.current() will not return c + # when called from inside a task + _del_global_client(c) + + result = await c.submit(foo) + assert result == 123 diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 5d9ece6ee54..521f8867522 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -9,6 +9,7 @@ from distributed import Client, Variable, worker_client, Nanny, wait, TimeoutError from distributed.metrics import time +from distributed.client import _del_global_client from distributed.compatibility import WINDOWS from distributed.utils_test import gen_cluster, inc, div from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -40,6 +41,23 @@ async def test_variable(c, s, a, b): assert time() < start + 5 +@gen_cluster(client=True) +async def test_variable_in_task(c, s, a, b): + x = Variable("x") + await x.set(123) + + def foo(): + y = Variable("x") + return y.get() + + # We want to make sure Client.current() will not return c + # when called from inside a task + _del_global_client(c) + + result = await c.submit(foo) + assert result == 123 + + @gen_cluster(client=True) async def test_delete_unset_variable(c, s, a, b): x = Variable() From 0584c0523a96935c6a1af22d25aa3e79e41cef75 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 8 Feb 2021 14:43:30 -0600 Subject: [PATCH 3/3] Update tests --- distributed/tests/test_queues.py | 29 +++++++++++++++-------------- distributed/tests/test_variable.py | 30 +++++++++++++++--------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 3431b5c6b29..8f400498854 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -5,9 +5,8 @@ import pytest from distributed import Client, Queue, Nanny, worker_client, wait, TimeoutError -from distributed.client import _del_global_client from distributed.metrics import time -from distributed.utils_test import gen_cluster, inc, div +from distributed.utils_test import gen_cluster, inc, div, popen from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -279,18 +278,20 @@ def get(): await c.gather([res, fut]) -@gen_cluster(client=True) -async def test_queue_in_task(c, s, a, b): - x = await Queue("x") - await x.put(123) +def test_queue_in_task(loop): + # Ensure that we can create a Queue inside a task on a + # worker in a separate Python process than the client + with popen(["dask-scheduler", "--no-dashboard"]): + with popen(["dask-worker", "127.0.0.1:8786"]): + with Client("tcp://127.0.0.1:8786", loop=loop) as c: + c.wait_for_workers(1) - def foo(): - y = Queue("x") - return y.get() + x = Queue("x") + x.put(123) - # We want to make sure Client.current() will not return c - # when called from inside a task - _del_global_client(c) + def foo(): + y = Queue("x") + return y.get() - result = await c.submit(foo) - assert result == 123 + result = c.submit(foo).result() + assert result == 123 diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 521f8867522..37b3c756be7 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -9,11 +9,9 @@ from distributed import Client, Variable, worker_client, Nanny, wait, TimeoutError from distributed.metrics import time -from distributed.client import _del_global_client from distributed.compatibility import WINDOWS -from distributed.utils_test import gen_cluster, inc, div +from distributed.utils_test import gen_cluster, inc, div, captured_logger, popen from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 -from distributed.utils_test import captured_logger @gen_cluster(client=True) @@ -41,21 +39,23 @@ async def test_variable(c, s, a, b): assert time() < start + 5 -@gen_cluster(client=True) -async def test_variable_in_task(c, s, a, b): - x = Variable("x") - await x.set(123) +def test_variable_in_task(loop): + # Ensure that we can create a Variable inside a task on a + # worker in a separate Python process than the client + with popen(["dask-scheduler", "--no-dashboard"]): + with popen(["dask-worker", "127.0.0.1:8786"]): + with Client("tcp://127.0.0.1:8786", loop=loop) as c: + c.wait_for_workers(1) - def foo(): - y = Variable("x") - return y.get() + x = Variable("x") + x.set(123) - # We want to make sure Client.current() will not return c - # when called from inside a task - _del_global_client(c) + def foo(): + y = Variable("x") + return y.get() - result = await c.submit(foo) - assert result == 123 + result = c.submit(foo).result() + assert result == 123 @gen_cluster(client=True)