Skip to content

Commit f6ec755

Browse files
authored
Allow customizing connection state reset (#1191)
A coroutine can be passed to the new `reset` argument of `create_pool` to control what happens to the connection when it is returned back to the pool by `release()`. By default `Connection.reset()` is called. Additionally, `Connection.get_reset_query` is renamed from `Connection._get_reset_query` to enable an alternative way of customizing the reset process via subclassing. Closes: #780 Closes: #1146
1 parent 3ef884e commit f6ec755

File tree

3 files changed

+87
-11
lines changed

3 files changed

+87
-11
lines changed

asyncpg/connection.py

+39-6
Original file line numberDiff line numberDiff line change
@@ -1515,11 +1515,10 @@ def terminate(self):
15151515
self._abort()
15161516
self._cleanup()
15171517

1518-
async def reset(self, *, timeout=None):
1518+
async def _reset(self):
15191519
self._check_open()
15201520
self._listeners.clear()
15211521
self._log_listeners.clear()
1522-
reset_query = self._get_reset_query()
15231522

15241523
if self._protocol.is_in_transaction() or self._top_xact is not None:
15251524
if self._top_xact is None or not self._top_xact._managed:
@@ -1531,10 +1530,36 @@ async def reset(self, *, timeout=None):
15311530
})
15321531

15331532
self._top_xact = None
1534-
reset_query = 'ROLLBACK;\n' + reset_query
1533+
await self.execute("ROLLBACK")
1534+
1535+
async def reset(self, *, timeout=None):
1536+
"""Reset the connection state.
1537+
1538+
Calling this will reset the connection session state to a state
1539+
resembling that of a newly obtained connection. Namely, an open
1540+
transaction (if any) is rolled back, open cursors are closed,
1541+
all `LISTEN <https://www.postgresql.org/docs/current/sql-listen.html>`_
1542+
registrations are removed, all session configuration
1543+
variables are reset to their default values, and all advisory locks
1544+
are released.
1545+
1546+
Note that the above describes the default query returned by
1547+
:meth:`Connection.get_reset_query`. If one overloads the method
1548+
by subclassing ``Connection``, then this method will do whatever
1549+
the overloaded method returns, except open transactions are always
1550+
terminated and any callbacks registered by
1551+
:meth:`Connection.add_listener` or :meth:`Connection.add_log_listener`
1552+
are removed.
15351553
1536-
if reset_query:
1537-
await self.execute(reset_query, timeout=timeout)
1554+
:param float timeout:
1555+
A timeout for resetting the connection. If not specified, defaults
1556+
to no timeout.
1557+
"""
1558+
async with compat.timeout(timeout):
1559+
await self._reset()
1560+
reset_query = self.get_reset_query()
1561+
if reset_query:
1562+
await self.execute(reset_query)
15381563

15391564
def _abort(self):
15401565
# Put the connection into the aborted state.
@@ -1695,7 +1720,15 @@ def _unwrap(self):
16951720
con_ref = self._proxy
16961721
return con_ref
16971722

1698-
def _get_reset_query(self):
1723+
def get_reset_query(self):
1724+
"""Return the query sent to server on connection release.
1725+
1726+
The query returned by this method is used by :meth:`Connection.reset`,
1727+
which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
1728+
the connection available to another acquirer.
1729+
1730+
.. versionadded:: 0.30.0
1731+
"""
16991732
if self._reset_query is not None:
17001733
return self._reset_query
17011734

asyncpg/pool.py

+32-4
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,12 @@ async def release(self, timeout):
210210
if budget is not None:
211211
budget -= time.monotonic() - started
212212

213-
await self._con.reset(timeout=budget)
213+
if self._pool._reset is not None:
214+
async with compat.timeout(budget):
215+
await self._con._reset()
216+
await self._pool._reset(self._con)
217+
else:
218+
await self._con.reset(timeout=budget)
214219
except (Exception, asyncio.CancelledError) as ex:
215220
# If the `reset` call failed, terminate the connection.
216221
# A new one will be created when `acquire` is called
@@ -313,7 +318,7 @@ class Pool:
313318

314319
__slots__ = (
315320
'_queue', '_loop', '_minsize', '_maxsize',
316-
'_init', '_connect', '_connect_args', '_connect_kwargs',
321+
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
317322
'_holders', '_initialized', '_initializing', '_closing',
318323
'_closed', '_connection_class', '_record_class', '_generation',
319324
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
@@ -327,6 +332,7 @@ def __init__(self, *connect_args,
327332
connect=None,
328333
setup=None,
329334
init=None,
335+
reset=None,
330336
loop,
331337
connection_class,
332338
record_class,
@@ -393,6 +399,7 @@ def __init__(self, *connect_args,
393399

394400
self._setup = setup
395401
self._init = init
402+
self._reset = reset
396403

397404
self._max_queries = max_queries
398405
self._max_inactive_connection_lifetime = \
@@ -1036,6 +1043,7 @@ def create_pool(dsn=None, *,
10361043
connect=None,
10371044
setup=None,
10381045
init=None,
1046+
reset=None,
10391047
loop=None,
10401048
connection_class=connection.Connection,
10411049
record_class=protocol.Record,
@@ -1125,7 +1133,7 @@ def create_pool(dsn=None, *,
11251133
11261134
:param coroutine setup:
11271135
A coroutine to prepare a connection right before it is returned
1128-
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
1136+
from :meth:`Pool.acquire()`. An example use
11291137
case would be to automatically set up notifications listeners for
11301138
all connections of a pool.
11311139
@@ -1137,6 +1145,25 @@ def create_pool(dsn=None, *,
11371145
or :meth:`Connection.set_type_codec() <\
11381146
asyncpg.connection.Connection.set_type_codec>`.
11391147
1148+
:param coroutine reset:
1149+
A coroutine to reset a connection before it is returned to the pool by
1150+
:meth:`Pool.release()`. The function is supposed
1151+
to reset any changes made to the database session so that the next
1152+
acquirer gets the connection in a well-defined state.
1153+
1154+
The default implementation calls :meth:`Connection.reset() <\
1155+
asyncpg.connection.Connection.reset>`, which runs the following::
1156+
1157+
SELECT pg_advisory_unlock_all();
1158+
CLOSE ALL;
1159+
UNLISTEN *;
1160+
RESET ALL;
1161+
1162+
The exact reset query is determined by detected server capabilities,
1163+
and a custom *reset* implementation can obtain the default query
1164+
by calling :meth:`Connection.get_reset_query() <\
1165+
asyncpg.connection.Connection.get_reset_query>`.
1166+
11401167
:param loop:
11411168
An asyncio event loop instance. If ``None``, the default
11421169
event loop will be used.
@@ -1165,7 +1192,7 @@ def create_pool(dsn=None, *,
11651192
Added the *record_class* parameter.
11661193
11671194
.. versionchanged:: 0.30.0
1168-
Added the *connect* parameter.
1195+
Added the *connect* and *reset* parameters.
11691196
"""
11701197
return Pool(
11711198
dsn,
@@ -1178,6 +1205,7 @@ def create_pool(dsn=None, *,
11781205
connect=connect,
11791206
setup=setup,
11801207
init=init,
1208+
reset=reset,
11811209
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
11821210
**connect_kwargs,
11831211
)

tests/test_pool.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -137,20 +137,31 @@ async def setup(con):
137137
async def test_pool_07(self):
138138
cons = set()
139139
connect_called = 0
140+
init_called = 0
141+
setup_called = 0
142+
reset_called = 0
140143

141144
async def connect(*args, **kwargs):
142145
nonlocal connect_called
143146
connect_called += 1
144147
return await pg_connection.connect(*args, **kwargs)
145148

146149
async def setup(con):
150+
nonlocal setup_called
147151
if con._con not in cons: # `con` is `PoolConnectionProxy`.
148152
raise RuntimeError('init was not called before setup')
153+
setup_called += 1
149154

150155
async def init(con):
156+
nonlocal init_called
151157
if con in cons:
152158
raise RuntimeError('init was called more than once')
153159
cons.add(con)
160+
init_called += 1
161+
162+
async def reset(con):
163+
nonlocal reset_called
164+
reset_called += 1
154165

155166
async def user(pool):
156167
async with pool.acquire() as con:
@@ -162,12 +173,16 @@ async def user(pool):
162173
max_size=5,
163174
connect=connect,
164175
init=init,
165-
setup=setup) as pool:
176+
setup=setup,
177+
reset=reset) as pool:
166178
users = asyncio.gather(*[user(pool) for _ in range(10)])
167179
await users
168180

169181
self.assertEqual(len(cons), 5)
170182
self.assertEqual(connect_called, 5)
183+
self.assertEqual(init_called, 5)
184+
self.assertEqual(setup_called, 10)
185+
self.assertEqual(reset_called, 10)
171186

172187
async def bad_connect(*args, **kwargs):
173188
return 1

0 commit comments

Comments
 (0)