Skip to content

Commit b0deeac

Browse files
committed
Allow customizing connection state reset
A coroutine can be passed to the new `reset` argument of `create_pool` to control what happens to the connection when it is returned back to the pool by `release()`. By default `Connection.reset()` is called. Additionally, `Connection.get_reset_query` is renamed from `Connection._get_reset_query` to enable an alternative way of customizing the reset process via subclassing. Closes: #780 Closes: #1146
1 parent 3ee19ba commit b0deeac

File tree

3 files changed

+62
-10
lines changed

3 files changed

+62
-10
lines changed

Diff for: asyncpg/connection.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -1515,11 +1515,10 @@ def terminate(self):
15151515
self._abort()
15161516
self._cleanup()
15171517

1518-
async def reset(self, *, timeout=None):
1518+
async def _reset(self):
15191519
self._check_open()
15201520
self._listeners.clear()
15211521
self._log_listeners.clear()
1522-
reset_query = self._get_reset_query()
15231522

15241523
if self._protocol.is_in_transaction() or self._top_xact is not None:
15251524
if self._top_xact is None or not self._top_xact._managed:
@@ -1531,10 +1530,14 @@ async def reset(self, *, timeout=None):
15311530
})
15321531

15331532
self._top_xact = None
1534-
reset_query = 'ROLLBACK;\n' + reset_query
1533+
await self.execute("ROLLBACK")
15351534

1536-
if reset_query:
1537-
await self.execute(reset_query, timeout=timeout)
1535+
async def reset(self, *, timeout=None):
1536+
async with compat.timeout(timeout):
1537+
await self._reset()
1538+
reset_query = self.get_reset_query()
1539+
if reset_query:
1540+
await self.execute(reset_query)
15381541

15391542
def _abort(self):
15401543
# Put the connection into the aborted state.
@@ -1695,7 +1698,13 @@ def _unwrap(self):
16951698
con_ref = self._proxy
16961699
return con_ref
16971700

1698-
def _get_reset_query(self):
1701+
def get_reset_query(self):
1702+
"""Return the query sent to server on connection release.
1703+
1704+
The query returned by this method is used by :meth:`Connection.reset`,
1705+
which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
1706+
the connection available to another acquirer.
1707+
"""
16991708
if self._reset_query is not None:
17001709
return self._reset_query
17011710

Diff for: asyncpg/pool.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,12 @@ async def release(self, timeout):
210210
if budget is not None:
211211
budget -= time.monotonic() - started
212212

213-
await self._con.reset(timeout=budget)
213+
if self._pool._reset is not None:
214+
async with compat.timeout(budget):
215+
await self._con._reset()
216+
await self._pool._reset(self._con)
217+
else:
218+
await self._con.reset(timeout=budget)
214219
except (Exception, asyncio.CancelledError) as ex:
215220
# If the `reset` call failed, terminate the connection.
216221
# A new one will be created when `acquire` is called
@@ -313,7 +318,7 @@ class Pool:
313318

314319
__slots__ = (
315320
'_queue', '_loop', '_minsize', '_maxsize',
316-
'_init', '_connect', '_connect_args', '_connect_kwargs',
321+
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
317322
'_holders', '_initialized', '_initializing', '_closing',
318323
'_closed', '_connection_class', '_record_class', '_generation',
319324
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
@@ -327,6 +332,7 @@ def __init__(self, *connect_args,
327332
connect=None,
328333
setup=None,
329334
init=None,
335+
reset=None,
330336
loop,
331337
connection_class,
332338
record_class,
@@ -393,6 +399,7 @@ def __init__(self, *connect_args,
393399

394400
self._setup = setup
395401
self._init = init
402+
self._reset = reset
396403

397404
self._max_queries = max_queries
398405
self._max_inactive_connection_lifetime = \
@@ -1036,6 +1043,7 @@ def create_pool(dsn=None, *,
10361043
connect=None,
10371044
setup=None,
10381045
init=None,
1046+
reset=None,
10391047
loop=None,
10401048
connection_class=connection.Connection,
10411049
record_class=protocol.Record,
@@ -1137,6 +1145,25 @@ def create_pool(dsn=None, *,
11371145
or :meth:`Connection.set_type_codec() <\
11381146
asyncpg.connection.Connection.set_type_codec>`.
11391147
1148+
:param coroutine reset:
1149+
A coroutine to reset a connection before it is returned to the pool by
1150+
:meth:`Pool.release() <pool.Pool.release>`. The function is supposed
1151+
to reset any changes made to the database session so that the next
1152+
acquirer gets the connection in a well-defined state.
1153+
1154+
The default implementation calls :meth:`Connection.reset() <\
1155+
asyncpg.connection.Connection.reset>`, which runs the following::
1156+
1157+
SELECT pg_advisory_unlock_all();
1158+
CLOSE ALL;
1159+
UNLISTEN *;
1160+
RESET ALL;
1161+
1162+
The exact reset query is determined by detected server capabilities,
1163+
and a custom *reset* implementation can obtain the default query
1164+
by calling :meth:`Connection.get_reset_query() <\
1165+
asyncpg.connection.Connection.get_reset_query>`.
1166+
11401167
:param loop:
11411168
An asyncio event loop instance. If ``None``, the default
11421169
event loop will be used.
@@ -1165,7 +1192,7 @@ def create_pool(dsn=None, *,
11651192
Added the *record_class* parameter.
11661193
11671194
.. versionchanged:: 0.30.0
1168-
Added the *connect* parameter.
1195+
Added the *connect* and *reset* parameters.
11691196
"""
11701197
return Pool(
11711198
dsn,
@@ -1178,6 +1205,7 @@ def create_pool(dsn=None, *,
11781205
connect=connect,
11791206
setup=setup,
11801207
init=init,
1208+
reset=reset,
11811209
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
11821210
**connect_kwargs,
11831211
)

Diff for: tests/test_pool.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -137,20 +137,31 @@ async def setup(con):
137137
async def test_pool_07(self):
138138
cons = set()
139139
connect_called = 0
140+
init_called = 0
141+
setup_called = 0
142+
reset_called = 0
140143

141144
async def connect(*args, **kwargs):
142145
nonlocal connect_called
143146
connect_called += 1
144147
return await pg_connection.connect(*args, **kwargs)
145148

146149
async def setup(con):
150+
nonlocal setup_called
147151
if con._con not in cons: # `con` is `PoolConnectionProxy`.
148152
raise RuntimeError('init was not called before setup')
153+
setup_called += 1
149154

150155
async def init(con):
156+
nonlocal init_called
151157
if con in cons:
152158
raise RuntimeError('init was called more than once')
153159
cons.add(con)
160+
init_called += 1
161+
162+
async def reset(con):
163+
nonlocal reset_called
164+
reset_called += 1
154165

155166
async def user(pool):
156167
async with pool.acquire() as con:
@@ -162,12 +173,16 @@ async def user(pool):
162173
max_size=5,
163174
connect=connect,
164175
init=init,
165-
setup=setup) as pool:
176+
setup=setup,
177+
reset=reset) as pool:
166178
users = asyncio.gather(*[user(pool) for _ in range(10)])
167179
await users
168180

169181
self.assertEqual(len(cons), 5)
170182
self.assertEqual(connect_called, 5)
183+
self.assertEqual(init_called, 5)
184+
self.assertEqual(setup_called, 10)
185+
self.assertEqual(reset_called, 10)
171186

172187
async def bad_connect(*args, **kwargs):
173188
return 1

0 commit comments

Comments
 (0)