diff --git a/edgedb/asyncio_pool.py b/edgedb/asyncio_pool.py index 2d3211f5..255f99d6 100644 --- a/edgedb/asyncio_pool.py +++ b/edgedb/asyncio_pool.py @@ -146,8 +146,12 @@ async def release(self, timeout): 'a free connection holder') if self._con.is_closed(): - # When closing, pool connections perform the necessary - # cleanup, so we don't have to do anything else here. + # This is usually the case when the connection is broken rather + # than closed by the user, so we need to call _release_on_close() + # here to release the holder back to the queue, because + # self._con._cleanup() was never called. On the other hand, it is + # safe to call self._release() twice - the second call is no-op. + self._release_on_close() return self._timeout = None diff --git a/edgedb/protocol/asyncio_proto.pyx b/edgedb/protocol/asyncio_proto.pyx index 2df4aa97..f00e0b10 100644 --- a/edgedb/protocol/asyncio_proto.pyx +++ b/edgedb/protocol/asyncio_proto.pyx @@ -19,6 +19,7 @@ import asyncio +from edgedb import errors from edgedb.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, @@ -109,7 +110,7 @@ cdef class AsyncIOProtocol(protocol.SansIOProtocol): self.disconnected_fut.set_exception(ConnectionResetError()) if self.msg_waiter is not None and not self.msg_waiter.done(): - self.msg_waiter.set_exception(ConnectionResetError()) + self.msg_waiter.set_exception(errors.ClientConnectionClosedError()) self.msg_waiter = None if self.transport is not None: diff --git a/tests/test_pool.py b/tests/test_pool.py index dbfd2dd0..ba2bc8fc 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -27,6 +27,7 @@ from edgedb import _testbase as tb from edgedb import asyncio_con from edgedb import asyncio_pool +from edgedb import errors class TestPool(tb.AsyncQueryTestCase): @@ -566,3 +567,89 @@ async def test_pool_properties(self): self.assertEqual(pool.free_size, max_size) await pool.aclose() + + async def _test_connection_broken(self, executor, broken_evt): + self.assertEqual(await executor.query_single("SELECT 123"), 123) + broken_evt.set() + with self.assertRaises(errors.ClientConnectionClosedError): + await executor.query_single("SELECT 123") + broken_evt.clear() + self.assertEqual(await executor.query_single("SELECT 123"), 123) + + tested = False + async for tx in executor.retrying_transaction(): + async with tx: + self.assertEqual(await tx.query_single("SELECT 123"), 123) + if tested: + break + tested = True + broken_evt.set() + try: + await tx.query_single("SELECT 123") + except errors.ClientConnectionClosedError: + broken_evt.clear() + raise + else: + self.fail("ConnectionError not raised!") + + async def test_pool_connection_broken(self): + con_args = self.get_connect_args() + broken = asyncio.Event() + done = asyncio.Event() + + async def proxy(r: asyncio.StreamReader, w: asyncio.StreamWriter): + while True: + reader = self.loop.create_task(r.read(65536)) + waiter = self.loop.create_task(broken.wait()) + await asyncio.wait( + [reader, waiter], + return_when=asyncio.FIRST_COMPLETED, + ) + if waiter.done(): + reader.cancel() + w.close() + break + else: + waiter.cancel() + data = await reader + if not data: + break + w.write(data) + + async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter): + ur, uw = await asyncio.open_connection( + con_args['host'], con_args['port'] + ) + done.clear() + task = self.loop.create_task(proxy(r, uw)) + try: + await proxy(ur, w) + finally: + try: + await task + finally: + done.set() + + server = await asyncio.start_server( + cb, '127.0.0.1', 0 + ) + port = server.sockets[0].getsockname()[1] + pool = await self.create_pool( + host='127.0.0.1', port=port, min_size=0, max_size=1 + ) + conargs = self.get_connect_args().copy() + conargs["database"] = self.con.dbname + conargs["timeout"] = 120 + conargs["host"] = "127.0.0.1" + conargs["port"] = port + conn = await edgedb.async_connect(**conargs) + try: + await self._test_connection_broken(conn, broken) + await self._test_connection_broken(pool, broken) + finally: + server.close() + await server.wait_closed() + await asyncio.wait_for(pool.aclose(), 5) + await asyncio.wait_for(conn.aclose(), 1) + broken.set() + await done.wait()