From f4cad66ecc26ceb0a8287f400f197a98464bfb36 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Tue, 24 Aug 2021 10:43:38 -0400 Subject: [PATCH] Fix broken pool connection cleanup When a connection is broken (not actively closed by the user), its connection holder in the pool is not cleaned up, leading to issues such as zombie connections in the pool or pool.aclose() hanging forever. To test retrying_transaction(), connection errors from wait_for_message() are wrapped into a retryable EdgeDB client error type. This partially fixes the same issue as #222, but the latter should aim for a more complete solution. --- edgedb/asyncio_pool.py | 8 ++- edgedb/protocol/asyncio_proto.pyx | 3 +- tests/test_pool.py | 87 +++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 3 deletions(-) diff --git a/edgedb/asyncio_pool.py b/edgedb/asyncio_pool.py index 2d3211f5..255f99d6 100644 --- a/edgedb/asyncio_pool.py +++ b/edgedb/asyncio_pool.py @@ -146,8 +146,12 @@ async def release(self, timeout): 'a free connection holder') if self._con.is_closed(): - # When closing, pool connections perform the necessary - # cleanup, so we don't have to do anything else here. + # This is usually the case when the connection is broken rather + # than closed by the user, so we need to call _release_on_close() + # here to release the holder back to the queue, because + # self._con._cleanup() was never called. On the other hand, it is + # safe to call self._release() twice - the second call is no-op. 
+ self._release_on_close() return self._timeout = None diff --git a/edgedb/protocol/asyncio_proto.pyx b/edgedb/protocol/asyncio_proto.pyx index 2df4aa97..f00e0b10 100644 --- a/edgedb/protocol/asyncio_proto.pyx +++ b/edgedb/protocol/asyncio_proto.pyx @@ -19,6 +19,7 @@ import asyncio +from edgedb import errors from edgedb.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, @@ -109,7 +110,7 @@ cdef class AsyncIOProtocol(protocol.SansIOProtocol): self.disconnected_fut.set_exception(ConnectionResetError()) if self.msg_waiter is not None and not self.msg_waiter.done(): - self.msg_waiter.set_exception(ConnectionResetError()) + self.msg_waiter.set_exception(errors.ClientConnectionClosedError()) self.msg_waiter = None if self.transport is not None: diff --git a/tests/test_pool.py b/tests/test_pool.py index dbfd2dd0..ba2bc8fc 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -27,6 +27,7 @@ from edgedb import _testbase as tb from edgedb import asyncio_con from edgedb import asyncio_pool +from edgedb import errors class TestPool(tb.AsyncQueryTestCase): @@ -566,3 +567,89 @@ async def test_pool_properties(self): self.assertEqual(pool.free_size, max_size) await pool.aclose() + + async def _test_connection_broken(self, executor, broken_evt): + self.assertEqual(await executor.query_single("SELECT 123"), 123) + broken_evt.set() + with self.assertRaises(errors.ClientConnectionClosedError): + await executor.query_single("SELECT 123") + broken_evt.clear() + self.assertEqual(await executor.query_single("SELECT 123"), 123) + + tested = False + async for tx in executor.retrying_transaction(): + async with tx: + self.assertEqual(await tx.query_single("SELECT 123"), 123) + if tested: + break + tested = True + broken_evt.set() + try: + await tx.query_single("SELECT 123") + except errors.ClientConnectionClosedError: + broken_evt.clear() + raise + else: + self.fail("ConnectionError not raised!") + + async def test_pool_connection_broken(self): + con_args = self.get_connect_args() + 
broken = asyncio.Event() + done = asyncio.Event() + + async def proxy(r: asyncio.StreamReader, w: asyncio.StreamWriter): + while True: + reader = self.loop.create_task(r.read(65536)) + waiter = self.loop.create_task(broken.wait()) + await asyncio.wait( + [reader, waiter], + return_when=asyncio.FIRST_COMPLETED, + ) + if waiter.done(): + reader.cancel() + w.close() + break + else: + waiter.cancel() + data = await reader + if not data: + break + w.write(data) + + async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter): + ur, uw = await asyncio.open_connection( + con_args['host'], con_args['port'] + ) + done.clear() + task = self.loop.create_task(proxy(r, uw)) + try: + await proxy(ur, w) + finally: + try: + await task + finally: + done.set() + + server = await asyncio.start_server( + cb, '127.0.0.1', 0 + ) + port = server.sockets[0].getsockname()[1] + pool = await self.create_pool( + host='127.0.0.1', port=port, min_size=0, max_size=1 + ) + conargs = self.get_connect_args().copy() + conargs["database"] = self.con.dbname + conargs["timeout"] = 120 + conargs["host"] = "127.0.0.1" + conargs["port"] = port + conn = await edgedb.async_connect(**conargs) + try: + await self._test_connection_broken(conn, broken) + await self._test_connection_broken(pool, broken) + finally: + server.close() + await server.wait_closed() + await asyncio.wait_for(pool.aclose(), 5) + await asyncio.wait_for(conn.aclose(), 1) + broken.set() + await done.wait()