Skip to content

Commit

Permalink
Close comm on CancelledError (dask#5656)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 13, 2022
1 parent 00cce09 commit 3e2e880
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
20 changes: 15 additions & 5 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ async def handle_comm(self, comm):
result = asyncio.ensure_future(result)
self._ongoing_coroutines.add(result)
result = await result
except (CommClosedError, CancelledError):
except (CommClosedError, asyncio.CancelledError):
if self.status in (Status.running, Status.paused):
logger.info("Lost connection to %r", address, exc_info=True)
break
Expand Down Expand Up @@ -663,15 +663,23 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw
response = await comm.read(deserializers=deserializers)
else:
response = None
except OSError:
except (asyncio.TimeoutError, OSError):
# On communication errors, we should simply close the communication
# Note that OSError includes CommClosedError and socket timeouts
force_close = True
raise
except asyncio.CancelledError:
# Do not reuse the comm to prevent the next call of send_recv from receiving
# data from this call and/or accidentally putting multiple waiters on read().
# Note that this relies on all Comm implementations to allow a write() in the
# middle of a read().
please_close = True
raise
finally:
if please_close:
await comm.close()
elif force_close:
if force_close:
comm.abort()
elif please_close:
await comm.close()

if isinstance(response, dict) and response.get("status") == "uncaught-error":
if comm.deserialize:
Expand Down Expand Up @@ -1084,6 +1092,8 @@ def reuse(self, addr, comm):
else:
self.occupied[addr].remove(comm)
if comm.closed():
# Either the user passed the close=True parameter to send_recv, or
# the RPC call raised OSError or CancelledError
self.semaphore.release()
else:
self.available[addr].add(comm)
Expand Down
23 changes: 23 additions & 0 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
async_wait_for,
captured_logger,
gen_cluster,
gen_test,
has_ipv6,
inc,
throws,
Expand Down Expand Up @@ -531,6 +532,28 @@ async def test_send_recv_args():
server.stop()


@gen_test(timeout=5)
async def test_send_recv_cancelled():
    """Check that cancelling ``send_recv`` closes the comm on both ends."""

    async def get_stuck(comm):
        # Handler that never replies: this future is never resolved.
        await asyncio.Future()

    handlers = {"get_stuck": get_stuck}
    server = Server(handlers)
    await server.listen(0)

    client_side = await connect(server.address, deserialize=False)
    # Poll until the server exposes a comm for this connection
    # (presumably registered once the incoming connection is accepted).
    while not server._comms:
        await asyncio.sleep(0.01)
    server_side = next(iter(server._comms))

    # wait_for cancels the pending call once the timeout expires and
    # surfaces that to the caller as asyncio.TimeoutError.
    pending_call = send_recv(client_side, op="get_stuck")
    with pytest.raises(asyncio.TimeoutError):
        await asyncio.wait_for(pending_call, timeout=0.1)

    # The client end must be closed as soon as the call has been cancelled...
    assert client_side.closed()
    # ...while the server end is polled until it reports closed as well.
    while not server_side.closed():
        await asyncio.sleep(0.01)


def test_coerce_to_address():
for arg in ["127.0.0.1:8786", ("127.0.0.1", 8786), ("127.0.0.1", "8786")]:
assert coerce_to_address(arg) == "tcp://127.0.0.1:8786"
Expand Down

0 comments on commit 3e2e880

Please sign in to comment.