Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set Client.as_current when entering ctx #6527

Merged
merged 11 commits into from
Jul 7, 2023
23 changes: 20 additions & 3 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _get_global_client() -> Client | None:

def _set_global_client(c: Client | None) -> None:
if c is not None:
c._set_as_default = True
_global_clients[_global_client_index[0]] = c
_global_client_index[0] += 1

Expand Down Expand Up @@ -899,6 +900,7 @@ def __init__(
deserializers = serializers
self._deserializers = deserializers
self.direct_to_workers = direct_to_workers
self._previous_as_current = None

# Communication
self.scheduler_comm = None
Expand Down Expand Up @@ -1092,6 +1094,10 @@ def current(cls, allow_global=True):
------
ValueError
If there is no client set, a ValueError is raised

See also
--------
default_client
"""
out = _current_client.get()
if out:
Expand Down Expand Up @@ -1397,8 +1403,6 @@ async def _ensure_connected(self, timeout=None):
bcomm = BatchedSend(interval="10ms", loop=self.loop)
bcomm.start(comm)
self.scheduler_comm = bcomm
if self._set_as_default:
_set_global_client(self)
self.status = "running"

for msg in self._pending_msg_buffer:
Expand Down Expand Up @@ -1486,13 +1490,19 @@ def _heartbeat(self):
def __enter__(self):
if not self._loop_runner.is_started():
self.start()
if self._set_as_default:
self._previous_as_current = _current_client.set(self)
return self

async def __aenter__(self):
await self
if self._set_as_default:
self._previous_as_current = _current_client.set(self)
return self

async def __aexit__(self, exc_type, exc_value, traceback):
if self._previous_as_current:
_current_client.reset(self._previous_as_current)
await self._close(
# if we're handling an exception, we assume that it's more
# important to deliver that exception than shutdown gracefully.
Expand All @@ -1501,6 +1511,8 @@ async def __aexit__(self, exc_type, exc_value, traceback):
)

def __exit__(self, exc_type, exc_value, traceback):
if self._previous_as_current:
_current_client.reset(self._previous_as_current)
self.close()

def __del__(self):
Expand Down Expand Up @@ -5526,6 +5538,10 @@ def default_client(c=None):
-------
c : Client
The client, if one has started

See also
--------
Client.current (alias)
"""
c = c or _get_global_client()
if c:
Expand Down Expand Up @@ -5878,7 +5894,8 @@ def temp_default_client(c):
old_exec = default_client()
_set_global_client(c)
try:
yield
with c.as_current():
yield
finally:
_set_global_client(old_exec)

Expand Down
145 changes: 132 additions & 13 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
Client,
Future,
_get_global_client,
_global_clients,
as_completed,
default_client,
ensure_default_client,
Expand Down Expand Up @@ -1142,10 +1143,10 @@ async def test_get_releases_data(c, s, a, b):
await asyncio.sleep(0.01)


def test_current(s, a, b, loop_in_thread):
loop = loop_in_thread
def test_current(s, loop):
with Client(s["address"], loop=loop) as c:
assert Client.current() is c
assert Client.current(allow_global=False) is c
with pytest.raises(
ValueError,
match=r"No clients found"
Expand All @@ -1156,6 +1157,121 @@ def test_current(s, a, b, loop_in_thread):
Client.current()
with Client(s["address"], loop=loop) as c:
assert Client.current() is c
assert Client.current(allow_global=False) is c


def test_current_nested(s, loop):
with pytest.raises(
ValueError,
match=r"No clients found"
r"\nStart a client and point it to the scheduler address"
r"\n from distributed import Client"
r"\n client = Client\('ip-addr-of-scheduler:8786'\)",
):
Client.current()

class MyException(Exception):
pass

with Client(s["address"], loop=loop) as c_outer:
assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer

with Client(s["address"], loop=loop) as c_inner:
assert Client.current() is c_inner
assert Client.current(allow_global=False) is c_inner

assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer

with pytest.raises(MyException):
with Client(s["address"], loop=loop) as c_inner2:
assert Client.current() is c_inner2
assert Client.current(allow_global=False) is c_inner2
raise MyException

assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer


@gen_cluster(nthreads=[])
async def test_current_nested_async(s):
with pytest.raises(
ValueError,
match=r"No clients found"
r"\nStart a client and point it to the scheduler address"
r"\n from distributed import Client"
r"\n client = Client\('ip-addr-of-scheduler:8786'\)",
):
Client.current()

class MyException(Exception):
pass

async with Client(s.address, asynchronous=True) as c_outer:
assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer

async with Client(s.address, asynchronous=True) as c_inner:
assert Client.current() is c_inner
assert Client.current(allow_global=False) is c_inner

assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer

with pytest.raises(MyException):
async with Client(s.address, asynchronous=True) as c_inner2:
assert Client.current() is c_inner2
assert Client.current(allow_global=False) is c_inner2
raise MyException

assert Client.current() is c_outer
assert Client.current(allow_global=False) is c_outer


@gen_cluster(nthreads=[])
async def test_current_concurrent(s):
client_1_started = asyncio.Event()
client_2_started = asyncio.Event()
stop_client_1 = asyncio.Event()
stop_client_2 = asyncio.Event()
client_2_stopped = asyncio.Event()

c1 = None
c2 = None

def _all_global_clients():
return [v for _, v in sorted(_global_clients.items())]

async def client_1():
nonlocal c1
async with Client(s.address, asynchronous=True) as c1:
assert _all_global_clients() == [c1]
assert Client.current() is c1
client_1_started.set()
await client_2_started.wait()
# c2 is the highest priority global client
assert _all_global_clients() == [c1, c2]
# but the contextvar means the current client is still us
assert Client.current() is c1
stop_client_2.set()
await stop_client_1.wait()

async def client_2():
nonlocal c2
await client_1_started.wait()
async with Client(s.address, asynchronous=True) as c2:
assert _all_global_clients() == [c1, c2]
assert Client.current() is c2
client_2_started.set()
await stop_client_2.wait()

assert _all_global_clients() == [c1]
# Client.current() is now based on _global_clients instead of the cvar
assert Client.current() is c1
stop_client_1.set()

await asyncio.gather(client_1(), client_2())


def test_global_clients(loop):
Expand Down Expand Up @@ -3294,16 +3410,20 @@ async def test_get_scheduler_default_client_config_interleaving(s):
await client.close()


@gen_cluster(client=True)
async def test_ensure_default_client(c, s, a, b):
assert c is default_client()

async with Client(s.address, set_as_default=False, asynchronous=True) as c2:
@gen_cluster()
async def test_ensure_default_client(s, a, b):
c = await Client(s.address, asynchronous=True)
try:
Comment on lines +3416 to +3417
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out we can only set current when entering ctx when set_as_default is also set. Otherwise this breaks tests like this.

That also makes sense considering that default_client and Client.current are effectively aliases

assert c is default_client()
assert c2 is not default_client()
ensure_default_client(c2)
assert c is not default_client()
assert c2 is default_client()

async with Client(s.address, set_as_default=False, asynchronous=True) as c2:
assert c is default_client()
assert c2 is not default_client()
ensure_default_client(c2)
assert c is not default_client()
assert c2 is default_client()
finally:
await c.close()


@gen_cluster()
Expand Down Expand Up @@ -4022,8 +4142,7 @@ async def test_as_current(c, s, a, b):
) as c2:
with temp_default_client(c):
assert Client.current() is c
with pytest.raises(ValueError):
Client.current(allow_global=False)
assert Client.current(allow_global=False) is c
with c1.as_current():
assert Client.current() is c1
assert Client.current(allow_global=True) is c1
Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,15 @@ def event_is_set(event_name):
assert not s.extensions["events"]._waiter_count


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
"""Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.

This typically happens if the object is being deserialized on the scheduler.
"""
obj = await Event()
pickled = pickle.dumps(obj)
await c.close()
async with Client(s.address, asynchronous=True) as c:
obj = await Event()
pickled = pickle.dumps(obj)

# We do not want to initialize a client during unpickling
with pytest.raises(ValueError):
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,4 @@ async def test_deserialize_client(c, s, a, b):
# Ensure cleanup
from distributed.client import _current_client

assert _current_client.get() is None
assert _current_client.get() is c
10 changes: 5 additions & 5 deletions distributed/tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,15 @@ def foo():
assert result == 123


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
"""Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.

This typically happens if the object is being deserialized on the scheduler.
"""
q = await Queue()
pickled = pickle.dumps(q)
await c.close()
async with Client(s.address, asynchronous=True) as c:
q = await Queue()
pickled = pickle.dumps(q)

# We do not want to initialize a client during unpickling
with pytest.raises(ValueError):
Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,15 +585,15 @@ async def test_release_failure(c, s, a, b, caplog):
await pool.close()


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
"""Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.

This typically happens if the object is being deserialized on the scheduler.
"""
sem = await Semaphore()
pickled = pickle.dumps(sem)
await c.close()
async with Client(s.address, asynchronous=True) as c:
sem = await Semaphore()
pickled = pickle.dumps(sem)

# We do not want to initialize a client during unpickling
with pytest.raises(ValueError):
Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,15 @@ async def test_variables_do_not_leak_client(c, s, a, b):
assert time() < start + 5


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
"""Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.

This typically happens if the object is being deserialized on the scheduler.
"""
obj = Variable("foo")
pickled = pickle.dumps(obj)
await c.close()
async with Client(s.address, asynchronous=True) as c:
obj = Variable("foo")
pickled = pickle.dumps(obj)

# We do not want to initialize a client during unpickling
with pytest.raises(ValueError):
Expand Down