Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix on_connect multiple call on acquire #552

Merged
merged 4 commits into from
Dec 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions aiopg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,6 @@ async def _acquire(self):
assert not conn.closed, conn
assert conn not in self._used, (conn, self._used)
self._used.add(conn)
if self._on_connect is not None:
await self._on_connect(conn)
return conn
else:
await self._cond.wait()
Expand Down Expand Up @@ -203,6 +201,8 @@ async def _fill_free_pool(self, override_min):
enable_uuid=self._enable_uuid,
echo=self._echo,
**self._conn_kwargs)
if self._on_connect is not None:
await self._on_connect(conn)
# raise exception if pool is closing
self._free.append(conn)
self._cond.notify()
Expand All @@ -221,6 +221,8 @@ async def _fill_free_pool(self, override_min):
enable_uuid=self._enable_uuid,
echo=self._echo,
**self._conn_kwargs)
if self._on_connect is not None:
await self._on_connect(conn)
# raise exception if pool is closing
self._free.append(conn)
self._cond.notify()
Expand Down
2 changes: 1 addition & 1 deletion docs/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ The basic usage is::

:param bool echo: executed log SQL queryes (disabled by default).

:param on_connect: a *callback coroutine* executed at once for every
:param on_connect: a *callback coroutine* executed once for every
created connection. May be used for setting up connection level
state like client encoding etc.

Expand Down
20 changes: 14 additions & 6 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,20 +548,28 @@ async def test_close_running_cursor(create_pool):
await cur.execute('SELECT pg_sleep(10)')


async def test_pool_on_connect(create_pool):
called = False
@pytest.mark.parametrize('pool_minsize', [0, 1])
async def test_pool_on_connect(create_pool, pool_minsize):
cb_called_times = 0

async def cb(connection):
nonlocal called
nonlocal cb_called_times
async with connection.cursor() as cur:
await cur.execute('SELECT 1')
data = await cur.fetchall()
assert [(1,)] == data
called = True
cb_called_times += 1

pool = await create_pool(on_connect=cb)
pool = await create_pool(
minsize=pool_minsize,
maxsize=1,
on_connect=cb
)

with (await pool.cursor()) as cur:
await cur.execute('SELECT 1')

with (await pool.cursor()) as cur:
await cur.execute('SELECT 1')

assert called
assert cb_called_times == 1