Skip to content

Commit c2adad5

Browse files
committed
Comments, more unittests
1 parent bb38776 commit c2adad5

File tree

4 files changed

+75
-15
lines changed

4 files changed

+75
-15
lines changed

asyncpg/_testbase.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,32 @@ def start_cluster(cls, ClusterCls, *,
190190
return _start_cluster(ClusterCls, cluster_kwargs, server_settings)
191191

192192

193+
def with_connection_options(**options):
194+
if not options:
195+
raise ValueError('no connection options were specified')
196+
197+
def wrap(func):
198+
func.__connect_options__ = options
199+
return func
200+
201+
return wrap
202+
203+
193204
class ConnectedTestCase(ClusterTestCase):
194205

195206
def getExtraConnectOptions(self):
196207
return {}
197208

198209
def setUp(self):
199210
super().setUp()
200-
opts = self.getExtraConnectOptions()
211+
212+
# Extract options set up with `with_connection_options`.
213+
test_func = getattr(self, self._testMethodName).__func__
214+
opts = getattr(test_func, '__connect_options__', {})
215+
201216
self.con = self.loop.run_until_complete(
202217
self.cluster.connect(database='postgres', loop=self.loop, **opts))
218+
203219
self.server_version = self.con.get_server_version()
204220

205221
def tearDown(self):

asyncpg/connection.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,10 +880,13 @@ def set_max_lifetime(self, new_lifetime):
880880
assert new_lifetime >= 0
881881
self._max_lifetime = new_lifetime
882882
for entry in self._entries.values():
883+
# For every entry cancel the existing callback
884+
# and setup a new one if necessary.
883885
self._set_entry_timeout(entry)
884886

885887
def get(self, query, *, promote=True):
886888
if not self._max_size:
889+
# The cache is disabled.
887890
return
888891

889892
entry = self._entries.get(query) # type: _StatementCacheEntry
@@ -898,6 +901,7 @@ def get(self, query, *, promote=True):
898901
return
899902

900903
if promote:
904+
# `promote` is `False` when `get()` is called by `has()`.
901905
self._entries.move_to_end(query, last=True)
902906

903907
return entry._statement
@@ -907,21 +911,31 @@ def has(self, query):
907911

908912
def put(self, query, statement):
909913
if not self._max_size:
914+
# The cache is disabled.
910915
return
911916

912917
self._entries[query] = self._new_entry(query, statement)
918+
919+
# Check if the cache is bigger than max_size and trim it
920+
# if necessary.
913921
self._maybe_cleanup()
914922

915923
def iter_statements(self):
916924
return (e._statement for e in self._entries.values())
917925

918926
def clear(self):
927+
# First, make sure that we cancel all scheduled callbacks.
919928
for entry in self._entries.values():
920929
self._clear_entry_callback(entry)
930+
931+
# Clear the entries dict.
921932
self._entries.clear()
922933

923934
def _set_entry_timeout(self, entry):
935+
# Clear the existing timeout.
924936
self._clear_entry_callback(entry)
937+
938+
# Set the new timeout if it's not 0.
925939
if self._max_lifetime:
926940
entry._cleanup_cb = self._loop.call_later(
927941
self._max_lifetime, self._on_entry_expired, entry)
@@ -943,9 +957,13 @@ def _clear_entry_callback(self, entry):
943957
entry._cleanup_cb.cancel()
944958

945959
def _maybe_cleanup(self):
960+
# Delete cache entries until the size of the cache is `max_size`.
946961
while len(self._entries) > self._max_size:
947962
old_query, old_entry = self._entries.popitem(last=False)
948963
self._clear_entry_callback(old_entry)
964+
965+
# Let the connection know that the statement was removed
966+
# from the cache.
949967
self._on_remove(old_entry._statement)
950968

951969

tests/test_prepare.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,10 +430,9 @@ async def test_prepare_statement_invalid(self):
430430
finally:
431431
await self.con.execute('DROP TABLE tab1')
432432

433+
@tb.with_connection_options(statement_cache_size=0)
433434
async def test_prepare_23_no_stmt_cache_seq(self):
434-
# Disable cache, which will force connections to use
435-
# anonymous prepared statements.
436-
self.con._stmt_cache.set_max_size(0)
435+
self.assertEqual(self.con._stmt_cache.get_max_size(), 0)
437436

438437
async def check_simple():
439438
# Run a simple query a few times.
@@ -462,8 +461,11 @@ async def check_simple():
462461
# operation.
463462
await check_simple()
464463

464+
@tb.with_connection_options(max_cached_statement_lifetime=142)
465465
async def test_prepare_24_max_lifetime(self):
466466
cache = self.con._stmt_cache
467+
468+
self.assertEqual(cache.get_max_lifetime(), 142)
467469
cache.set_max_lifetime(1)
468470

469471
s = await self.con.prepare('SELECT 1')
@@ -479,3 +481,35 @@ async def test_prepare_24_max_lifetime(self):
479481

480482
s = await self.con.prepare('SELECT 1')
481483
self.assertIsNot(s._state, state)
484+
485+
@tb.with_connection_options(max_cached_statement_lifetime=0.5)
486+
async def test_prepare_25_max_lifetime_reset(self):
487+
cache = self.con._stmt_cache
488+
489+
s = await self.con.prepare('SELECT 1')
490+
state = s._state
491+
492+
# Disable max_lifetime
493+
cache.set_max_lifetime(0)
494+
495+
await asyncio.sleep(1, loop=self.loop)
496+
497+
# The statement should still be cached (as we disabled the timeout).
498+
s = await self.con.prepare('SELECT 1')
499+
self.assertIs(s._state, state)
500+
501+
@tb.with_connection_options(max_cached_statement_lifetime=0.5)
502+
async def test_prepare_26_max_lifetime_max_size(self):
503+
cache = self.con._stmt_cache
504+
505+
s = await self.con.prepare('SELECT 1')
506+
state = s._state
507+
508+
# Disable max_lifetime
509+
cache.set_max_size(0)
510+
511+
s = await self.con.prepare('SELECT 1')
512+
self.assertIsNot(s._state, state)
513+
514+
# Check that nothing crashes after the initial timeout
515+
await asyncio.sleep(1, loop=self.loop)

tests/test_timeout.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,7 @@ async def test_invalid_timeout(self):
128128

129129
class TestConnectionCommandTimeout(tb.ConnectedTestCase):
130130

131-
def getExtraConnectOptions(self):
132-
return {
133-
'command_timeout': 0.2
134-
}
135-
131+
@tb.with_connection_options(command_timeout=0.2)
136132
async def test_command_timeout_01(self):
137133
for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
138134
with self.assertRaises(asyncio.TimeoutError), \
@@ -151,12 +147,8 @@ async def _get_statement(self, query, timeout):
151147

152148
class TestTimeoutCoversPrepare(tb.ConnectedTestCase):
153149

154-
def getExtraConnectOptions(self):
155-
return {
156-
'__connection_class__': SlowPrepareConnection,
157-
'command_timeout': 0.3
158-
}
159-
150+
@tb.with_connection_options(__connection_class__=SlowPrepareConnection,
151+
command_timeout=0.3)
160152
async def test_timeout_covers_prepare_01(self):
161153
for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
162154
with self.assertRaises(asyncio.TimeoutError):

0 commit comments

Comments
 (0)