diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 6ac2a09d..79711c0c 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -756,6 +756,44 @@ async def fetchrow( return None return data[0] + async def fetchmany( + self, query, args, *, timeout: float=None, record_class=None + ): + """Run a query for each sequence of arguments in *args* + and return the results as a list of :class:`Record`. + + :param query: + Query to execute. + :param args: + An iterable containing sequences of arguments for the query. + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for records returned by this method. + Must be a subclass of :class:`~asyncpg.Record`. If not specified, + a per-connection *record_class* is used. + + :return list: + A list of :class:`~asyncpg.Record` instances. If specified, the + actual type of list elements would be *record_class*. + + Example: + + .. code-block:: pycon + + >>> rows = await con.fetchmany(''' + ... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a; + ... ''', [('x', 1), ('y', 2), ('z', 3)]) + >>> rows + [, , ] + + .. versionadded:: 0.30.0 + """ + self._check_open() + return await self._executemany( + query, args, timeout, return_rows=True, record_class=record_class + ) + async def copy_from_table(self, table_name, *, output, columns=None, schema_name=None, timeout=None, format=None, oids=None, delimiter=None, @@ -1896,17 +1934,27 @@ async def __execute( ) return result, stmt - async def _executemany(self, query, args, timeout): + async def _executemany( + self, + query, + args, + timeout, + return_rows=False, + record_class=None, + ): executor = lambda stmt, timeout: self._protocol.bind_execute_many( state=stmt, args=args, portal_name='', timeout=timeout, + return_rows=return_rows, ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: with self._time_and_log(query, args, timeout): - result, _ = await self._do_execute(query, executor, timeout) + result, _ = await self._do_execute( + query, executor, timeout, record_class=record_class + ) return result async def _do_execute( diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 8a00d64b..19ced84b 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -609,6 +609,22 @@ async def fetchrow(self, query, *args, timeout=None, record_class=None): record_class=record_class ) + async def fetchmany(self, query, args, *, timeout=None, record_class=None): + """Run a query for each sequence of arguments in *args* + and return the results as a list of :class:`Record`. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.fetchmany() + `. + + .. versionadded:: 0.30.0 + """ + async with self.acquire() as con: + return await con.fetchmany( + query, args, timeout=timeout, record_class=record_class + ) + async def copy_from_table( self, table_name, diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 195d0056..d66a5ad3 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -210,6 +210,27 @@ async def fetchrow(self, *args, timeout=None): return None return data[0] + @connresource.guarded + async def fetchmany(self, args, *, timeout=None): + """Execute the statement and return a list of :class:`Record` objects. + + :param args: Query arguments. + :param float timeout: Optional timeout value in seconds. + + :return: A list of :class:`Record` instances. + + .. versionadded:: 0.30.0 + """ + return await self.__do_execute( + lambda protocol: protocol.bind_execute_many( + self._state, + args, + portal_name='', + timeout=timeout, + return_rows=True, + ) + ) + @connresource.guarded async def executemany(self, args, *, timeout: float=None): """Execute the statement for each sequence of arguments in *args*. @@ -222,7 +243,12 @@ async def executemany(self, args, *, timeout: float=None): """ return await self.__do_execute( lambda protocol: protocol.bind_execute_many( - self._state, args, '', timeout)) + self._state, + args, + portal_name='', + timeout=timeout, + return_rows=False, + )) async def __do_execute(self, executor): protocol = self._connection._protocol diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index f6a0b08f..34c7c712 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -171,7 +171,7 @@ cdef class CoreProtocol: cdef _bind_execute(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit) cdef bint _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data) + object bind_data, bint return_rows) cdef bint _bind_execute_many_more(self, bint first=*) cdef _bind_execute_many_fail(self, object error, bint first=*) cdef _bind(self, str portal_name, str stmt_name, diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 4ef438cd..19857878 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -1020,12 +1020,12 @@ cdef class CoreProtocol: self._send_bind_message(portal_name, stmt_name, bind_data, limit) cdef bint _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data): + object bind_data, bint return_rows): self._ensure_connected() self._set_state(PROTOCOL_BIND_EXECUTE_MANY) - self.result = None - self._discard_data = True + self.result = [] if return_rows else None + self._discard_data = not return_rows self._execute_iter = bind_data self._execute_portal_name = portal_name self._execute_stmt_name = stmt_name diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 1459d908..bd2ad05c 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -212,6 +212,7 @@ cdef class BaseProtocol(CoreProtocol): args, portal_name: str, timeout, + return_rows: bool, ): if self.cancel_waiter is not None: await self.cancel_waiter @@ -237,7 +238,8 @@ cdef class BaseProtocol(CoreProtocol): more = self._bind_execute_many( portal_name, state.name, - arg_bufs) # network op + arg_bufs, + return_rows) # network op self.last_query = state.query self.statement = state diff --git a/tests/test_execute.py b/tests/test_execute.py index 78d8c124..f8a0e43a 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -139,6 +139,45 @@ async def test_executemany_basic(self): ('a', 1), ('b', 2), ('c', 3), ('d', 4) ]) + async def test_executemany_returning(self): + result = await self.con.fetchmany(''' + INSERT INTO exmany VALUES($1, $2) RETURNING a, b + ''', [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + + # Empty set + await self.con.fetchmany(''' + INSERT INTO exmany VALUES($1, $2) RETURNING a, b + ''', ()) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + + # Without "RETURNING" + result = await self.con.fetchmany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [('e', 5), ('f', 6)]) + self.assertEqual(result, []) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6) + ]) + async def test_executemany_bad_input(self): with self.assertRaisesRegex( exceptions.DataError, diff --git a/tests/test_prepare.py b/tests/test_prepare.py index 5911ccf2..661021bd 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -611,3 +611,17 @@ async def test_prepare_explicitly_named(self): 'prepared statement "foobar" already exists', ): await self.con.prepare('select 1', name='foobar') + + async def test_prepare_fetchmany(self): + tr = self.con.transaction() + await tr.start() + try: + await self.con.execute('CREATE TABLE fetchmany (a int, b text)') + + stmt = await self.con.prepare( + 'INSERT INTO fetchmany (a, b) VALUES ($1, $2) RETURNING a, b' + ) + result = await stmt.fetchmany([(1, 'a'), (2, 'b'), (3, 'c')]) + self.assertEqual(result, [(1, 'a'), (2, 'b'), (3, 'c')]) + finally: + await tr.rollback()