Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broken pool connection cleanup #230

Merged
merged 1 commit into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions edgedb/asyncio_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,12 @@ async def release(self, timeout):
'a free connection holder')

if self._con.is_closed():
# When closing, pool connections perform the necessary
# cleanup, so we don't have to do anything else here.
# This is usually the case when the connection is broken rather
# than closed by the user, so we need to call _release_on_close()
# here to release the holder back to the queue, because
# self._con._cleanup() was never called. On the other hand, it is
# safe to call self._release() twice - the second call is no-op.
self._release_on_close()
return

self._timeout = None
Expand Down
3 changes: 2 additions & 1 deletion edgedb/protocol/asyncio_proto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import asyncio

from edgedb import errors
from edgedb.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
Expand Down Expand Up @@ -109,7 +110,7 @@ cdef class AsyncIOProtocol(protocol.SansIOProtocol):
self.disconnected_fut.set_exception(ConnectionResetError())

if self.msg_waiter is not None and not self.msg_waiter.done():
self.msg_waiter.set_exception(ConnectionResetError())
self.msg_waiter.set_exception(errors.ClientConnectionClosedError())
self.msg_waiter = None

if self.transport is not None:
Expand Down
87 changes: 87 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from edgedb import _testbase as tb
from edgedb import asyncio_con
from edgedb import asyncio_pool
from edgedb import errors


class TestPool(tb.AsyncQueryTestCase):
Expand Down Expand Up @@ -566,3 +567,89 @@ async def test_pool_properties(self):
self.assertEqual(pool.free_size, max_size)

await pool.aclose()

async def _test_connection_broken(self, executor, broken_evt):
self.assertEqual(await executor.query_single("SELECT 123"), 123)
broken_evt.set()
with self.assertRaises(errors.ClientConnectionClosedError):
await executor.query_single("SELECT 123")
broken_evt.clear()
self.assertEqual(await executor.query_single("SELECT 123"), 123)

tested = False
async for tx in executor.retrying_transaction():
async with tx:
self.assertEqual(await tx.query_single("SELECT 123"), 123)
if tested:
break
tested = True
broken_evt.set()
try:
await tx.query_single("SELECT 123")
except errors.ClientConnectionClosedError:
broken_evt.clear()
raise
else:
self.fail("ConnectionError not raised!")

async def test_pool_connection_broken(self):
con_args = self.get_connect_args()
broken = asyncio.Event()
done = asyncio.Event()

async def proxy(r: asyncio.StreamReader, w: asyncio.StreamWriter):
while True:
reader = self.loop.create_task(r.read(65536))
waiter = self.loop.create_task(broken.wait())
await asyncio.wait(
[reader, waiter],
return_when=asyncio.FIRST_COMPLETED,
)
if waiter.done():
reader.cancel()
w.close()
break
else:
waiter.cancel()
data = await reader
if not data:
break
w.write(data)

async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter):
ur, uw = await asyncio.open_connection(
con_args['host'], con_args['port']
)
done.clear()
task = self.loop.create_task(proxy(r, uw))
try:
await proxy(ur, w)
finally:
try:
await task
finally:
done.set()

server = await asyncio.start_server(
cb, '127.0.0.1', 0
)
port = server.sockets[0].getsockname()[1]
pool = await self.create_pool(
host='127.0.0.1', port=port, min_size=0, max_size=1
)
conargs = self.get_connect_args().copy()
conargs["database"] = self.con.dbname
conargs["timeout"] = 120
conargs["host"] = "127.0.0.1"
conargs["port"] = port
conn = await edgedb.async_connect(**conargs)
try:
await self._test_connection_broken(conn, broken)
await self._test_connection_broken(pool, broken)
finally:
server.close()
await server.wait_closed()
await asyncio.wait_for(pool.aclose(), 5)
await asyncio.wait_for(conn.aclose(), 1)
broken.set()
await done.wait()