From 66a9c3dc53d180be70d8138abc96d145a2f1ef78 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 8 Jun 2022 14:28:15 +0200 Subject: [PATCH 1/9] Prefer current over default Client --- distributed/client.py | 16 +++++++++++++--- distributed/diagnostics/progressbar.py | 4 ++-- distributed/utils_test.py | 2 +- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 7bc4e1f7913..43a2a3b5e95 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1447,10 +1447,12 @@ def _heartbeat(self): def __enter__(self): if not self._loop_runner.is_started(): self.start() + self._previous_as_current = _current_client.set(self) return self async def __aenter__(self): await self + self._previous_as_current = _current_client.set(self) return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -1460,9 +1462,17 @@ async def __aexit__(self, exc_type, exc_value, traceback): fast=exc_type is not None ) + try: + _current_client.reset(self._previous_as_current) + except ValueError: + raise RuntimeError("Closed Clients in wrong order") def __exit__(self, exc_type, exc_value, traceback): self.close() + try: + _current_client.reset(self._previous_as_current) + except ValueError: + raise RuntimeError("Closed Clients in wrong order") def __del__(self): # If the loop never got assigned, we failed early in the constructor, @@ -4900,7 +4910,7 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): """ if timeout is not None and isinstance(timeout, (Number, str)): timeout = parse_timedelta(timeout, default="s") - client = default_client() + client = Client.current() result = client.sync(_wait, fs, timeout=timeout, return_when=return_when) return result @@ -4997,7 +5007,7 @@ def __init__(self, futures=None, loop=None, with_results=False, raise_errors=Tru self.futures = defaultdict(lambda: 0) self.queue = pyQueue() self.lock = threading.Lock() - self.loop = loop or default_client().loop + self.loop = loop or Client.current().loop self.thread_condition = threading.Condition() self.with_results = with_results self.raise_errors = raise_errors @@ -5538,7 +5548,7 @@ def temp_default_client(c): c : Client This is what default_client() will return within the with-block. """ - old_exec = default_client() + old_exec = Client.current() _set_global_client(c) try: yield diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index 828b2e3f4d0..a394d4e6ede 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -14,7 +14,7 @@ import dask from dask.utils import key_split -from distributed.client import default_client, futures_of +from distributed.client import Client, futures_of from distributed.core import ( CommClosedError, clean_exception, @@ -30,7 +30,7 @@ def get_scheduler(scheduler): if scheduler is None: - return default_client().scheduler.address + return Client.current().scheduler.address return coerce_to_address(scheduler) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 9213916f563..610dbab446b 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -691,7 +691,7 @@ async def wait_for_workers(): for w in workers_by_pid.values() ] try: - client = default_client() + client = Client.current() except ValueError: pass else: From 5046fb95bf446aa6eb96c832dee7b57b99861cbe Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 14 Dec 2022 15:15:13 +0100 Subject: [PATCH 2/9] only set current if default is allowed --- distributed/client.py | 33 ++++++++++++++++---------- distributed/diagnostics/progressbar.py | 4 ++-- distributed/tests/test_client.py | 24 ++++++++++--------- distributed/utils_test.py | 2 +- 4 files changed, 37 insertions(+), 26 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 43a2a3b5e95..f049abf60cc 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -138,6 +138,7 @@ def _get_global_client() -> Client | None: def _set_global_client(c: Client | None) -> None: if c is not None: + c._set_as_default = True _global_clients[_global_client_index[0]] = c _global_client_index[0] += 1 @@ -867,6 +868,7 @@ def __init__( deserializers = serializers self._deserializers = deserializers self.direct_to_workers = direct_to_workers + self._previous_as_current = None # Communication self.scheduler_comm = None @@ -1061,6 +1063,10 @@ def current(cls, allow_global=True): ------ ValueError If there is no client set, a ValueError is raised + + See also + -------- + default_client """ out = _current_client.get() if out: @@ -1447,12 +1453,14 @@ def _heartbeat(self): def __enter__(self): if not self._loop_runner.is_started(): self.start() - self._previous_as_current = _current_client.set(self) + if self._set_as_default: + self._previous_as_current = _current_client.set(self) return self async def __aenter__(self): await self - self._previous_as_current = _current_client.set(self) + if self._set_as_default: + self._previous_as_current = _current_client.set(self) return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -1462,17 +1470,13 @@ async def __aexit__(self, exc_type, exc_value, traceback): fast=exc_type is not None ) - try: + if self._previous_as_current: _current_client.reset(self._previous_as_current) - except ValueError: - raise RuntimeError("Closed Clients in wrong order") def __exit__(self, exc_type, exc_value, traceback): self.close() - try: + if self._previous_as_current: _current_client.reset(self._previous_as_current) - except ValueError: - raise RuntimeError("Closed Clients in wrong order") def __del__(self): # If the loop never got assigned, we failed early in the constructor, @@ -4910,7 +4914,7 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): """ if timeout is not None and isinstance(timeout, (Number, str)): timeout = parse_timedelta(timeout, default="s") - client = Client.current() + client = default_client() result = client.sync(_wait, fs, timeout=timeout, return_when=return_when) return result @@ -5007,7 +5011,7 @@ def __init__(self, futures=None, loop=None, with_results=False, raise_errors=Tru self.futures = defaultdict(lambda: 0) self.queue = pyQueue() self.lock = threading.Lock() - self.loop = loop or Client.current().loop + self.loop = loop or default_client().loop self.thread_condition = threading.Condition() self.with_results = with_results self.raise_errors = raise_errors @@ -5207,6 +5211,10 @@ def default_client(c=None): ------- c : Client The client, if one has started + + See also + -------- + Client.current (alias) """ c = c or _get_global_client() if c: @@ -5548,10 +5556,11 @@ def temp_default_client(c): c : Client This is what default_client() will return within the with-block. """ - old_exec = Client.current() + old_exec = default_client() _set_global_client(c) try: - yield + with c.as_current(): + yield finally: _set_global_client(old_exec) diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index a394d4e6ede..828b2e3f4d0 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -14,7 +14,7 @@ import dask from dask.utils import key_split -from distributed.client import Client, futures_of +from distributed.client import default_client, futures_of from distributed.core import ( CommClosedError, clean_exception, @@ -30,7 +30,7 @@ def get_scheduler(scheduler): if scheduler is None: - return Client.current().scheduler.address + return default_client().scheduler.address return coerce_to_address(scheduler) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8f85ab5d297..341bfd2eb26 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3365,16 +3365,20 @@ def test_default_get(loop_in_thread): assert dask.base.get_scheduler() == pre_get -@gen_cluster(client=True) -async def test_ensure_default_client(c, s, a, b): - assert c is default_client() - - async with Client(s.address, set_as_default=False, asynchronous=True) as c2: +@gen_cluster() +async def test_ensure_default_client(s, a, b): + c = await Client(s.address, asynchronous=True) + try: assert c is default_client() - assert c2 is not default_client() - ensure_default_client(c2) - assert c is not default_client() - assert c2 is default_client() + + async with Client(s.address, set_as_default=False, asynchronous=True) as c2: + assert c is default_client() + assert c2 is not default_client() + ensure_default_client(c2) + assert c is not default_client() + assert c2 is default_client() + finally: + await c.close() @gen_cluster() @@ -4072,8 +4076,6 @@ async def test_as_current(c, s, a, b): with temp_default_client(c): assert Client.current() is c - with pytest.raises(ValueError): - Client.current(allow_global=False) with c1.as_current(): assert Client.current() is c1 assert Client.current(allow_global=True) is c1 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 610dbab446b..9213916f563 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -691,7 +691,7 @@ async def wait_for_workers(): for w in workers_by_pid.values() ] try: - client = Client.current() + client = default_client() except ValueError: pass else: From d1f8191b0f8f57eaf0bbb4e7b9531229c0c696cf Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 29 Jun 2023 14:30:00 +0100 Subject: [PATCH 3/9] fix test_unpickle_without_client tests --- distributed/tests/test_events.py | 10 +++++----- distributed/tests/test_queues.py | 10 +++++----- distributed/tests/test_semaphore.py | 10 +++++----- distributed/tests/test_variable.py | 10 +++++----- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/distributed/tests/test_events.py b/distributed/tests/test_events.py index 5d6f4bbe843..b26df7ad399 100644 --- a/distributed/tests/test_events.py +++ b/distributed/tests/test_events.py @@ -224,15 +224,15 @@ def event_is_set(event_name): assert not s.extensions["events"]._waiter_count -@gen_cluster(client=True, nthreads=[]) -async def test_unpickle_without_client(c, s): +@gen_cluster(nthreads=[]) +async def test_unpickle_without_client(s): """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. This typically happens if the object is being deserialized on the scheduler. """ - obj = await Event() - pickled = pickle.dumps(obj) - await c.close() + async with Client(s.address, asynchronous=True) as c: + obj = await Event() + pickled = pickle.dumps(obj) # We do not want to initialize a client during unpickling with pytest.raises(ValueError): diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index ef92caa77bd..8a0dd96ca9b 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -303,15 +303,15 @@ def foo(): assert result == 123 -@gen_cluster(client=True, nthreads=[]) -async def test_unpickle_without_client(c, s): +@gen_cluster(nthreads=[]) +async def test_unpickle_without_client(s): """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. This typically happens if the object is being deserialized on the scheduler. """ - q = await Queue() - pickled = pickle.dumps(q) - await c.close() + async with Client(s.address, asynchronous=True) as c: + q = await Queue() + pickled = pickle.dumps(q) # We do not want to initialize a client during unpickling with pytest.raises(ValueError): diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 025a965396a..0cd98e1613d 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -585,15 +585,15 @@ async def test_release_failure(c, s, a, b, caplog): await pool.close() -@gen_cluster(client=True, nthreads=[]) -async def test_unpickle_without_client(c, s): +@gen_cluster(nthreads=[]) +async def test_unpickle_without_client(s): """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. This typically happens if the object is being deserialized on the scheduler. """ - sem = await Semaphore() - pickled = pickle.dumps(sem) - await c.close() + async with Client(s.address, asynchronous=True) as c: + sem = await Semaphore() + pickled = pickle.dumps(sem) # We do not want to initialize a client during unpickling with pytest.raises(ValueError): diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 2e711e9940a..00618da70d8 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -297,15 +297,15 @@ async def test_variables_do_not_leak_client(c, s, a, b): assert time() < start + 5 -@gen_cluster(client=True, nthreads=[]) -async def test_unpickle_without_client(c, s): +@gen_cluster(nthreads=[]) +async def test_unpickle_without_client(s): """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. This typically happens if the object is being deserialized on the scheduler. """ - obj = Variable("foo") - pickled = pickle.dumps(obj) - await c.close() + async with Client(s.address, asynchronous=True) as c: + obj = Variable("foo") + pickled = pickle.dumps(obj) # We do not want to initialize a client during unpickling with pytest.raises(ValueError): From 3e7ccf5e062eb3133a002a2145cc803f0dd99f38 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 29 Jun 2023 14:39:34 +0100 Subject: [PATCH 4/9] swap reseting cvar and closing client this fixes test_client.py::test_default_get and test_client.py::test_Future_exception_sync_2 --- distributed/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 7d2b46a46c6..ac74ef80caf 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1503,19 +1503,19 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): + if self._previous_as_current: + _current_client.reset(self._previous_as_current) await self._close( # if we're handling an exception, we assume that it's more # important to deliver that exception than shutdown gracefully. fast=exc_type is not None ) - if self._previous_as_current: - _current_client.reset(self._previous_as_current) def __exit__(self, exc_type, exc_value, traceback): - self.close() if self._previous_as_current: _current_client.reset(self._previous_as_current) + self.close() def __del__(self): # If the loop never got assigned, we failed early in the constructor, From b12c55893c60b305feeef6cea70f9948e0572457 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 29 Jun 2023 15:01:46 +0100 Subject: [PATCH 5/9] fix test_publish.py::test_deserialize_client --- distributed/tests/test_publish.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index fe898d15131..f1107b57619 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -300,4 +300,4 @@ async def test_deserialize_client(c, s, a, b): # Ensure cleanup from distributed.client import _current_client - assert _current_client.get() is None + assert _current_client.get() is c From dee32bdee5894c94833e29907fde5ec490aadefc Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 29 Jun 2023 16:33:31 +0100 Subject: [PATCH 6/9] Update distributed/tests/test_client.py --- distributed/tests/test_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 25d668bf623..fe7c2b7195c 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4026,6 +4026,7 @@ async def test_as_current(c, s, a, b): ) as c2: with temp_default_client(c): assert Client.current() is c + assert Client.current(allow_global=False) is c with c1.as_current(): assert Client.current() is c1 assert Client.current(allow_global=True) is c1 From 3843413c468267ba45a15adfd00c83842c2ee908 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 29 Jun 2023 16:38:30 +0100 Subject: [PATCH 7/9] don't re-set the global client on reconnect --- distributed/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ac74ef80caf..ec55b7f2c03 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1403,8 +1403,6 @@ async def _ensure_connected(self, timeout=None): bcomm = BatchedSend(interval="10ms", loop=self.loop) bcomm.start(comm) self.scheduler_comm = bcomm - if self._set_as_default: - _set_global_client(self) self.status = "running" for msg in self._pending_msg_buffer: From 8bf3e8c9b91e7653bfc1d463e1a155d893028906 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 29 Jun 2023 16:47:58 +0100 Subject: [PATCH 8/9] test the new behaviour of Client.current --- distributed/tests/test_client.py | 120 ++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fe7c2b7195c..fe46790b990 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -62,6 +62,7 @@ Client, Future, _get_global_client, + _global_clients, as_completed, default_client, ensure_default_client, @@ -1142,10 +1143,10 @@ async def test_get_releases_data(c, s, a, b): await asyncio.sleep(0.01) -def test_current(s, a, b, loop_in_thread): - loop = loop_in_thread +def test_current(s, loop): with Client(s["address"], loop=loop) as c: assert Client.current() is c + assert Client.current(allow_global=False) is c with pytest.raises( ValueError, match=r"No clients found" @@ -1156,6 +1157,121 @@ def test_current(s, a, b, loop_in_thread): Client.current() with Client(s["address"], loop=loop) as c: assert Client.current() is c + assert Client.current(allow_global=False) is c + + +def test_current_nested(s, loop): + with pytest.raises( + ValueError, + match=r"No clients found" + r"\nStart a client and point it to the scheduler address" + r"\n from distributed import Client" + r"\n client = Client\('ip-addr-of-scheduler:8786'\)", + ): + Client.current() + + class MyException(Exception): + pass + + with Client(s["address"], loop=loop) as c_outer: + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + with Client(s["address"], loop=loop) as c_inner: + assert Client.current() is c_inner + assert Client.current(allow_global=False) is c_inner + + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + with pytest.raises(MyException): + with Client(s["address"], loop=loop) as c_inner2: + assert Client.current() is c_inner2 + assert Client.current(allow_global=False) is c_inner2 + raise MyException + + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + +@gen_cluster(nthreads=[]) +async def test_current_nested_async(s): + with pytest.raises( + ValueError, + match=r"No clients found" + r"\nStart a client and point it to the scheduler address" + r"\n from distributed import Client" + r"\n client = Client\('ip-addr-of-scheduler:8786'\)", + ): + Client.current() + + class MyException(Exception): + pass + + async with Client(s.address, asynchronous=True) as c_outer: + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + async with Client(s.address, asynchronous=True) as c_inner: + assert Client.current() is c_inner + assert Client.current(allow_global=False) is c_inner + + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + with pytest.raises(MyException): + async with Client(s.address, asynchronous=True) as c_inner2: + assert Client.current() is c_inner2 + assert Client.current(allow_global=False) is c_inner2 + raise MyException + + assert Client.current() is c_outer + assert Client.current(allow_global=False) is c_outer + + +@gen_cluster(nthreads=[]) +async def test_current_concurrent(s): + client_1_started = asyncio.Event() + client_2_started = asyncio.Event() + stop_client_1 = asyncio.Event() + stop_client_2 = asyncio.Event() + client_2_stopped = asyncio.Event() + + c1 = None + c2 = None + + def _all_global_clients(): + return [v for _, v in sorted(_global_clients.items())] + + async def client_1(): + nonlocal c1 + async with Client(s.address, asynchronous=True) as c1: + assert _all_global_clients() == [c1] + assert Client.current() is c1 + client_1_started.set() + await client_2_started.wait() + # c2 is the highest priority global client + assert _all_global_clients() == [c1, c2] + # but the contextvar means the current client is still us + assert Client.current() is c1 + stop_client_2.set() + await stop_client_1.wait() + + async def client_2(): + nonlocal c2 + await client_1_started.wait() + async with Client(s.address, asynchronous=True) as c2: + assert _all_global_clients() == [c1, c2] + assert Client.current() is c2 + client_2_started.set() + await stop_client_2.wait() + + assert _all_global_clients() == [c1] + # Client.current() is now based on _global_clients instead of the cvar + assert Client.current() is c1 + stop_client_1.set() + + await asyncio.gather(client_1(), client_2()) def test_global_clients(loop): From 2538ac1361f0c3e6bce3675df81b04bfce24ccb7 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 7 Jul 2023 11:39:15 +0100 Subject: [PATCH 9/9] add note --- distributed/tests/test_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fe46790b990..2ed731a160e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3412,6 +3412,7 @@ async def test_get_scheduler_default_client_config_interleaving(s): @gen_cluster() async def test_ensure_default_client(s, a, b): + # Note: this test will fail if you use `async with Client` c = await Client(s.address, asynchronous=True) try: assert c is default_client()