Skip to content

Commit

Permalink
Add fetchmany to execute many *and* return rows (#1175)
Browse files Browse the repository at this point in the history

Co-authored-by: Elvis Pranskevichus <elvis@edgedb.com>
  • Loading branch information
rossmacarthur and elprans authored Oct 18, 2024
1 parent b732b4f commit 73f2209
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 8 deletions.
52 changes: 50 additions & 2 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,44 @@ async def fetchrow(
return None
return data[0]

async def fetchmany(
self, query, args, *, timeout: float=None, record_class=None
):
"""Run a query for each sequence of arguments in *args*
and return the results as a list of :class:`Record`.
:param query:
Query to execute.
:param args:
An iterable containing sequences of arguments for the query.
:param float timeout:
Optional timeout value in seconds.
:param type record_class:
If specified, the class to use for records returned by this method.
Must be a subclass of :class:`~asyncpg.Record`. If not specified,
a per-connection *record_class* is used.
:return list:
A list of :class:`~asyncpg.Record` instances. If specified, the
actual type of list elements would be *record_class*.
Example:
.. code-block:: pycon
>>> rows = await con.fetchmany('''
... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a;
... ''', [('x', 1), ('y', 2), ('z', 3)])
>>> rows
[<Record row=('x',)>, <Record row=('y',)>, <Record row=('z',)>]
.. versionadded:: 0.30.0
"""
self._check_open()
return await self._executemany(
query, args, timeout, return_rows=True, record_class=record_class
)

async def copy_from_table(self, table_name, *, output,
columns=None, schema_name=None, timeout=None,
format=None, oids=None, delimiter=None,
Expand Down Expand Up @@ -1896,17 +1934,27 @@ async def __execute(
)
return result, stmt

async def _executemany(self, query, args, timeout):
async def _executemany(
self,
query,
args,
timeout,
return_rows=False,
record_class=None,
):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
state=stmt,
args=args,
portal_name='',
timeout=timeout,
return_rows=return_rows,
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
with self._time_and_log(query, args, timeout):
result, _ = await self._do_execute(query, executor, timeout)
result, _ = await self._do_execute(
query, executor, timeout, record_class=record_class
)
return result

async def _do_execute(
Expand Down
16 changes: 16 additions & 0 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,22 @@ async def fetchrow(self, query, *args, timeout=None, record_class=None):
record_class=record_class
)

async def fetchmany(self, query, args, *, timeout=None, record_class=None):
"""Run a query for each sequence of arguments in *args*
and return the results as a list of :class:`Record`.
Pool performs this operation using one of its connections. Other than
that, it behaves identically to
:meth:`Connection.fetchmany()
<asyncpg.connection.Connection.fetchmany>`.
.. versionadded:: 0.30.0
"""
async with self.acquire() as con:
return await con.fetchmany(
query, args, timeout=timeout, record_class=record_class
)

async def copy_from_table(
self,
table_name,
Expand Down
28 changes: 27 additions & 1 deletion asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,27 @@ async def fetchrow(self, *args, timeout=None):
return None
return data[0]

@connresource.guarded
async def fetchmany(self, args, *, timeout=None):
"""Execute the statement and return a list of :class:`Record` objects.
:param args: Query arguments.
:param float timeout: Optional timeout value in seconds.
:return: A list of :class:`Record` instances.
.. versionadded:: 0.30.0
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state,
args,
portal_name='',
timeout=timeout,
return_rows=True,
)
)

@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
"""Execute the statement for each sequence of arguments in *args*.
Expand All @@ -222,7 +243,12 @@ async def executemany(self, args, *, timeout: float=None):
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state, args, '', timeout))
self._state,
args,
portal_name='',
timeout=timeout,
return_rows=False,
))

async def __do_execute(self, executor):
protocol = self._connection._protocol
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ cdef class CoreProtocol:
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
object bind_data, bint return_rows)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,
Expand Down
6 changes: 3 additions & 3 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1020,12 +1020,12 @@ cdef class CoreProtocol:
self._send_bind_message(portal_name, stmt_name, bind_data, limit)

cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):
object bind_data, bint return_rows):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

self.result = None
self._discard_data = True
self.result = [] if return_rows else None
self._discard_data = not return_rows
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name
Expand Down
4 changes: 3 additions & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ cdef class BaseProtocol(CoreProtocol):
args,
portal_name: str,
timeout,
return_rows: bool,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
Expand All @@ -237,7 +238,8 @@ cdef class BaseProtocol(CoreProtocol):
more = self._bind_execute_many(
portal_name,
state.name,
arg_bufs) # network op
arg_bufs,
return_rows) # network op

self.last_query = state.query
self.statement = state
Expand Down
39 changes: 39 additions & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,45 @@ async def test_executemany_basic(self):
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

async def test_executemany_returning(self):
result = await self.con.fetchmany('''
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''', [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Empty set
await self.con.fetchmany('''
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''', ())
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Without "RETURNING"
result = await self.con.fetchmany('''
INSERT INTO exmany VALUES($1, $2)
''', [('e', 5), ('f', 6)])
self.assertEqual(result, [])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
])

async def test_executemany_bad_input(self):
with self.assertRaisesRegex(
exceptions.DataError,
Expand Down
14 changes: 14 additions & 0 deletions tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,17 @@ async def test_prepare_explicitly_named(self):
'prepared statement "foobar" already exists',
):
await self.con.prepare('select 1', name='foobar')

async def test_prepare_fetchmany(self):
tr = self.con.transaction()
await tr.start()
try:
await self.con.execute('CREATE TABLE fetchmany (a int, b text)')

stmt = await self.con.prepare(
'INSERT INTO fetchmany (a, b) VALUES ($1, $2) RETURNING a, b'
)
result = await stmt.fetchmany([(1, 'a'), (2, 'b'), (3, 'c')])
self.assertEqual(result, [(1, 'a'), (2, 'b'), (3, 'c')])
finally:
await tr.rollback()

0 comments on commit 73f2209

Please sign in to comment.