Skip to content

Commit

Permalink
Make it possible to specify a statement name in Connection.prepare() (#…
Browse files Browse the repository at this point in the history
…846)

This adds the new `name` keyword argument to `Connection.prepare()` and
`PreparedStatement.get_name()` method returning the name of a statement.

Some users of asyncpg might find it useful to be able to control how
prepared statements are named, especially when a custom prepared
statement caching scheme is in use.  Specifically, This should help with
pgbouncer support in SQLAlchemy asyncpg dialect.

Fixes: #837.
  • Loading branch information
elprans authored Nov 16, 2021
1 parent a2a9237 commit 03a3d18
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
27 changes: 22 additions & 5 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ async def _get_statement(
query,
timeout,
*,
named: bool=False,
use_cache: bool=True,
named=False,
use_cache=True,
ignore_custom_codec=False,
record_class=None
):
Expand All @@ -385,7 +385,9 @@ async def _get_statement(
len(query) > self._config.max_cacheable_statement_size):
use_cache = False

if use_cache or named:
if isinstance(named, str):
stmt_name = named
elif use_cache or named:
stmt_name = self._get_unique_id('stmt')
else:
stmt_name = ''
Expand Down Expand Up @@ -526,11 +528,21 @@ def cursor(
record_class,
)

async def prepare(self, query, *, timeout=None, record_class=None):
async def prepare(
self,
query,
*,
name=None,
timeout=None,
record_class=None,
):
"""Create a *prepared statement* for the specified query.
:param str query:
Text of the query to create a prepared statement for.
:param str name:
Optional name of the returned prepared statement. If not
specified, the name is auto-generated.
:param float timeout:
Optional timeout value in seconds.
:param type record_class:
Expand All @@ -544,9 +556,13 @@ async def prepare(self, query, *, timeout=None, record_class=None):
.. versionchanged:: 0.22.0
Added the *record_class* parameter.
.. versionchanged:: 0.25.0
Added the *name* parameter.
"""
return await self._prepare(
query,
name=name,
timeout=timeout,
use_cache=False,
record_class=record_class,
Expand All @@ -556,6 +572,7 @@ async def _prepare(
self,
query,
*,
name=None,
timeout=None,
use_cache: bool=False,
record_class=None
Expand All @@ -564,7 +581,7 @@ async def _prepare(
stmt = await self._get_statement(
query,
timeout,
named=True,
named=True if name is None else name,
use_cache=use_cache,
record_class=record_class,
)
Expand Down
8 changes: 8 additions & 0 deletions asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def __init__(self, connection, query, state):
state.attach()
self._last_status = None

@connresource.guarded
def get_name(self) -> str:
"""Return the name of this prepared statement.
.. versionadded:: 0.25.0
"""
return self._state.name

@connresource.guarded
def get_query(self) -> str:
"""Return the text of the query for this prepared statement.
Expand Down
11 changes: 11 additions & 0 deletions tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,14 @@ async def test_prepare_does_not_use_cache(self):
# prepare with disabled cache
await self.con.prepare('select 1')
self.assertEqual(len(cache), 0)

async def test_prepare_explicitly_named(self):
ps = await self.con.prepare('select 1', name='foobar')
self.assertEqual(ps.get_name(), 'foobar')
self.assertEqual(await self.con.fetchval('EXECUTE foobar'), 1)

with self.assertRaisesRegex(
exceptions.DuplicatePreparedStatementError,
'prepared statement "foobar" already exists',
):
await self.con.prepare('select 1', name='foobar')

0 comments on commit 03a3d18

Please sign in to comment.