Skip to content

Commit 57c9ffd

Browse files
committed
Use the general statement cache for type introspection
Type introspection queries now rely on the general statement cache instead of ad-hoc prepared statements. Fixes: #198.
1 parent a8d871c commit 57c9ffd

File tree

2 files changed

+77
-33
lines changed

2 files changed

+77
-33
lines changed

asyncpg/connection.py

+28-32
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class Connection(metaclass=ConnectionMeta):
3838
Connections are created by calling :func:`~asyncpg.connection.connect`.
3939
"""
4040

41-
__slots__ = ('_protocol', '_transport', '_loop', '_types_stmt',
42-
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
41+
__slots__ = ('_protocol', '_transport', '_loop',
42+
'_top_xact', '_uid', '_aborted',
4343
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
4444
'_listeners', '_server_version', '_server_caps',
4545
'_intro_query', '_reset_query', '_proxy',
@@ -53,8 +53,6 @@ def __init__(self, protocol, transport, loop,
5353
self._protocol = protocol
5454
self._transport = transport
5555
self._loop = loop
56-
self._types_stmt = None
57-
self._type_by_name_stmt = None
5856
self._top_xact = None
5957
self._uid = 0
6058
self._aborted = False
@@ -286,14 +284,17 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
286284
stmt_name = ''
287285

288286
statement = await self._protocol.prepare(stmt_name, query, timeout)
289-
290287
ready = statement._init_types()
291288
if ready is not True:
292-
if self._types_stmt is None:
293-
self._types_stmt = await self.prepare(self._intro_query)
294-
295-
types = await self._types_stmt.fetch(list(ready))
289+
types, intro_stmt = await self.__execute(
290+
self._intro_query, (list(ready),), 0, timeout)
296291
self._protocol.get_settings().register_data_types(types)
292+
if not intro_stmt.name and not statement.name:
293+
# The introspection query has used an anonymous statement,
294+
# which has blown away the anonymous statement we've prepared
295+
# for the query, so we need to re-prepare it.
296+
statement = await self._protocol.prepare(
297+
stmt_name, query, timeout)
297298

298299
if use_cache:
299300
self._stmt_cache.put(query, statement)
@@ -886,12 +887,8 @@ async def set_type_codec(self, typename, *,
886887
"asyncpg 0.13.0. Use the `format` keyword argument instead.",
887888
DeprecationWarning, stacklevel=2)
888889

889-
if self._type_by_name_stmt is None:
890-
self._type_by_name_stmt = await self.prepare(
891-
introspection.TYPE_BY_NAME)
892-
893-
typeinfo = await self._type_by_name_stmt.fetchrow(
894-
typename, schema)
890+
typeinfo = await self.fetchrow(
891+
introspection.TYPE_BY_NAME, typename, schema)
895892
if not typeinfo:
896893
raise ValueError('unknown type: {}.{}'.format(schema, typename))
897894

@@ -921,12 +918,8 @@ async def reset_type_codec(self, typename, *, schema='public'):
921918
.. versionadded:: 0.12.0
922919
"""
923920

924-
if self._type_by_name_stmt is None:
925-
self._type_by_name_stmt = await self.prepare(
926-
introspection.TYPE_BY_NAME)
927-
928-
typeinfo = await self._type_by_name_stmt.fetchrow(
929-
typename, schema)
921+
typeinfo = await self.fetchrow(
922+
introspection.TYPE_BY_NAME, typename, schema)
930923
if not typeinfo:
931924
raise ValueError('unknown type: {}.{}'.format(schema, typename))
932925

@@ -949,12 +942,8 @@ async def set_builtin_type_codec(self, typename, *,
949942
"""
950943
self._check_open()
951944

952-
if self._type_by_name_stmt is None:
953-
self._type_by_name_stmt = await self.prepare(
954-
introspection.TYPE_BY_NAME)
955-
956-
typeinfo = await self._type_by_name_stmt.fetchrow(
957-
typename, schema)
945+
typeinfo = await self.fetchrow(
946+
introspection.TYPE_BY_NAME, typename, schema)
958947
if not typeinfo:
959948
raise ValueError('unknown type: {}.{}'.format(schema, typename))
960949

@@ -1209,18 +1198,25 @@ def _drop_global_statement_cache(self):
12091198
self._drop_local_statement_cache()
12101199

12111200
async def _execute(self, query, args, limit, timeout, return_status=False):
1201+
with self._stmt_exclusive_section:
1202+
result, _ = await self.__execute(
1203+
query, args, limit, timeout, return_status=return_status)
1204+
return result
1205+
1206+
async def __execute(self, query, args, limit, timeout,
1207+
return_status=False):
12121208
executor = lambda stmt, timeout: self._protocol.bind_execute(
12131209
stmt, args, '', limit, return_status, timeout)
12141210
timeout = self._protocol._get_timeout(timeout)
1215-
with self._stmt_exclusive_section:
1216-
return await self._do_execute(query, executor, timeout)
1211+
return await self._do_execute(query, executor, timeout)
12171212

12181213
async def _executemany(self, query, args, timeout):
12191214
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
12201215
stmt, args, '', timeout)
12211216
timeout = self._protocol._get_timeout(timeout)
12221217
with self._stmt_exclusive_section:
1223-
return await self._do_execute(query, executor, timeout)
1218+
result, _ = await self._do_execute(query, executor, timeout)
1219+
return result
12241220

12251221
async def _do_execute(self, query, executor, timeout, retry=True):
12261222
if timeout is None:
@@ -1269,10 +1265,10 @@ async def _do_execute(self, query, executor, timeout, retry=True):
12691265
if self._protocol.is_in_transaction() or not retry:
12701266
raise
12711267
else:
1272-
result = await self._do_execute(
1268+
return await self._do_execute(
12731269
query, executor, timeout, retry=False)
12741270

1275-
return result
1271+
return result, stmt
12761272

12771273

12781274
async def connect(dsn=None, *,

tests/test_introspection.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
MAX_RUNTIME = 0.1
1212

1313

14-
class TestTimeout(tb.ConnectedTestCase):
14+
class TestIntrospection(tb.ConnectedTestCase):
1515
@classmethod
1616
def setUpClass(cls):
1717
super().setUpClass()
@@ -44,3 +44,51 @@ async def test_introspection_on_large_db(self):
4444

4545
with self.assertRunUnder(MAX_RUNTIME):
4646
await self.con.fetchval('SELECT $1::int[]', [1, 2])
47+
48+
@tb.with_connection_options(statement_cache_size=0)
49+
async def test_introspection_no_stmt_cache_01(self):
50+
self.assertEqual(self.con._stmt_cache.get_max_size(), 0)
51+
await self.con.fetchval('SELECT $1::int[]', [1, 2])
52+
53+
await self.con.execute('''
54+
CREATE EXTENSION IF NOT EXISTS hstore
55+
''')
56+
57+
try:
58+
await self.con.set_builtin_type_codec(
59+
'hstore', codec_name='pg_contrib.hstore')
60+
finally:
61+
await self.con.execute('''
62+
DROP EXTENSION hstore
63+
''')
64+
65+
self.assertEqual(self.con._uid, 0)
66+
67+
@tb.with_connection_options(max_cacheable_statement_size=1)
68+
async def test_introspection_no_stmt_cache_02(self):
69+
# max_cacheable_statement_size will disable caching both for
70+
# the user query and for the introspection query.
71+
await self.con.fetchval('SELECT $1::int[]', [1, 2])
72+
73+
await self.con.execute('''
74+
CREATE EXTENSION IF NOT EXISTS hstore
75+
''')
76+
77+
try:
78+
await self.con.set_builtin_type_codec(
79+
'hstore', codec_name='pg_contrib.hstore')
80+
finally:
81+
await self.con.execute('''
82+
DROP EXTENSION hstore
83+
''')
84+
85+
self.assertEqual(self.con._uid, 0)
86+
87+
@tb.with_connection_options(max_cacheable_statement_size=10000)
88+
async def test_introspection_no_stmt_cache_03(self):
89+
# max_cacheable_statement_size will disable caching for
90+
# the user query but not for the introspection query.
91+
await self.con.fetchval(
92+
"SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2])
93+
94+
self.assertEqual(self.con._uid, 1)

0 commit comments

Comments
 (0)