@@ -41,10 +41,12 @@ class Connection(metaclass=ConnectionMeta):
4141 '_stmt_cache_max_size' , '_stmt_cache' , '_stmts_to_close' ,
4242 '_addr' , '_opts' , '_command_timeout' , '_listeners' ,
4343 '_server_version' , '_server_caps' , '_intro_query' ,
44- '_reset_query' , '_proxy' , '_stmt_exclusive_section' )
44+ '_reset_query' , '_proxy' , '_stmt_exclusive_section' ,
45+ '_max_cached_statement_use_count' )
4546
4647 def __init__ (self , protocol , transport , loop , addr , opts , * ,
47- statement_cache_size , command_timeout ):
48+ statement_cache_size , command_timeout ,
49+ max_cached_statement_use_count ):
4850 self ._protocol = protocol
4951 self ._transport = transport
5052 self ._loop = loop
@@ -60,6 +62,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6062 self ._stmt_cache_max_size = statement_cache_size
6163 self ._stmt_cache = collections .OrderedDict ()
6264 self ._stmts_to_close = set ()
65+ self ._max_cached_statement_use_count = max_cached_statement_use_count
6366
6467 if command_timeout is not None :
6568 try :
@@ -240,13 +243,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
240243 use_cache = self ._stmt_cache_max_size > 0
241244 if use_cache :
242245 try :
243- state = self ._stmt_cache [query ]
246+ holder = self ._stmt_cache [query ]
244247 except KeyError :
245248 pass
246249 else :
247- self ._stmt_cache .move_to_end (query , last = True )
248- if not state .closed :
249- return state
250+ if holder .use_count < self ._max_cached_statement_use_count :
251+ holder .use_count += 1
252+
253+ if holder .statement .closed :
254+ self ._stmt_cache .pop (query )
255+ else :
256+ self ._stmt_cache .move_to_end (query , last = True )
257+ return holder .statement
258+ else :
259+ self ._stmt_cache .pop (query )
250260
251261 protocol = self ._protocol
252262
@@ -255,9 +265,9 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
255265 else :
256266 stmt_name = ''
257267
258- state = await protocol .prepare (stmt_name , query , timeout )
268+ statement = await protocol .prepare (stmt_name , query , timeout )
259269
260- ready = state ._init_types ()
270+ ready = statement ._init_types ()
261271 if ready is not True :
262272 if self ._types_stmt is None :
263273 self ._types_stmt = await self .prepare (self ._intro_query )
@@ -267,16 +277,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
267277
268278 if use_cache :
269279 if len (self ._stmt_cache ) > self ._stmt_cache_max_size - 1 :
270- old_query , old_state = self ._stmt_cache .popitem (last = False )
271- self ._maybe_gc_stmt (old_state )
272- self ._stmt_cache [query ] = state
280+ old_query , old_holder = self ._stmt_cache .popitem (last = False )
281+ self ._maybe_gc_stmt (old_holder . statement )
282+ self ._stmt_cache [query ] = _StatementCacheHolder ( statement )
273283
274284 # If we've just created a new statement object, check if there
275285 # are any statements for GC.
276286 if self ._stmts_to_close :
277287 await self ._cleanup_stmts ()
278288
279- return state
289+ return statement
280290
281291 def cursor (self , query , * args , prefetch = None , timeout = None ):
282292 """Return a *cursor factory* for the specified query.
@@ -442,8 +452,8 @@ def _get_unique_id(self, prefix):
442452 return '__asyncpg_{}_{}__' .format (prefix , self ._uid )
443453
444454 def _close_stmts (self ):
445- for stmt in self ._stmt_cache .values ():
446- stmt .mark_closed ()
455+ for holder in self ._stmt_cache .values ():
456+ holder . statement .mark_closed ()
447457
448458 for stmt in self ._stmts_to_close :
449459 stmt .mark_closed ()
@@ -657,6 +667,7 @@ async def connect(dsn=None, *,
657667 loop = None ,
658668 timeout = 60 ,
659669 statement_cache_size = 100 ,
670+ max_cached_statement_use_count = 100 ,
660671 command_timeout = None ,
661672 __connection_class__ = Connection ,
662673 ** opts ):
@@ -692,6 +703,10 @@ async def connect(dsn=None, *,
692703 :param float timeout: connection timeout in seconds.
693704
694705 :param int statement_cache_size: the size of prepared statement LRU cache.
706+ Pass ``0`` to disable the cache.
707+
708+ :param int max_cached_statement_use_count: max number of uses for a cached
709+ prepared statement.
695710
696711 :param float command_timeout: the default timeout for operations on
697712 this connection (the default is no timeout).
@@ -710,6 +725,9 @@ async def connect(dsn=None, *,
710725 ... print(types)
711726 >>> asyncio.get_event_loop().run_until_complete(run())
712727 [<Record typname='bool' typnamespace=11 ...
728+
729+ .. versionchanged:: 0.10.0
730+ Added ``max_cached_statement_use_count`` parameter.
713731 """
714732 if loop is None :
715733 loop = asyncio .get_event_loop ()
@@ -753,13 +771,24 @@ async def connect(dsn=None, *,
753771 tr .close ()
754772 raise
755773
756- con = __connection_class__ (pr , tr , loop , addr , opts ,
757- statement_cache_size = statement_cache_size ,
758- command_timeout = command_timeout )
774+ con = __connection_class__ (
775+ pr , tr , loop , addr , opts ,
776+ statement_cache_size = statement_cache_size ,
777+ max_cached_statement_use_count = max_cached_statement_use_count ,
778+ command_timeout = command_timeout )
779+
759780 pr .set_connection (con )
760781 return con
761782
762783
784+ class _StatementCacheHolder :
785+ __slots__ = ('statement' , 'use_count' )
786+
787+ def __init__ (self , statement ):
788+ self .use_count = 1
789+ self .statement = statement
790+
791+
763792class _Atomic :
764793 __slots__ = ('_acquired' ,)
765794
0 commit comments