diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index 303abe7059..4d6140750e 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -1892,8 +1892,16 @@ async def next(self) -> AsyncGridOut: next_file = await super().next() return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session) - async def to_list(self) -> list[AsyncGridOut]: - return [x async for x in self] # noqa: C416,RUF100 + async def to_list(self, length: Optional[int] = None) -> list[AsyncGridOut]: + """Convert the cursor to a list.""" + if length is None: + return [x async for x in self] # noqa: C416,RUF100 + if length < 1: + raise ValueError("to_list() length must be greater than 0") + ret = [] + for _ in range(length): + ret.append(await self.next()) + return ret __anext__ = next diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py index 1e3d265d4b..bc2e29a61d 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -1878,8 +1878,16 @@ def next(self) -> GridOut: next_file = super().next() return GridOut(self._root_collection, file_document=next_file, session=self.session) - def to_list(self) -> list[GridOut]: - return [x for x in self] # noqa: C416,RUF100 + def to_list(self, length: Optional[int] = None) -> list[GridOut]: + """Convert the cursor to a list.""" + if length is None: + return [x for x in self] # noqa: C416,RUF100 + if length < 1: + raise ValueError("to_list() length must be greater than 0") + ret = [] + for _ in range(length): + ret.append(self.next()) + return ret __next__ = next diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index b28f983b12..b2cd345f63 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -346,13 +346,17 @@ async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]: else: return None - async def _next_batch(self, 
result: list) -> bool: - """Get all available documents from the cursor.""" + async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + """Get all or some available documents from the cursor.""" if not len(self._data) and not self._killed: await self._refresh() if len(self._data): - result.extend(self._data) - self._data.clear() + if total is None: + result.extend(self._data) + self._data.clear() + else: + for _ in range(min(len(self._data), total)): + result.append(self._data.popleft()) return True else: return False @@ -381,21 +385,32 @@ async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]: async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.close() - async def to_list(self) -> list[_DocumentType]: + async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``. To use:: >>> await cursor.to_list() + Or, to read at most n items from the cursor:: + + >>> await cursor.to_list(n) + If the cursor is empty or has no more results, an empty list will be returned. .. 
versionadded:: 4.9 """ res: list[_DocumentType] = [] + remaining = length + if isinstance(length, int) and length < 1: + raise ValueError("to_list() length must be greater than 0") while self.alive: - if not await self._next_batch(res): + if not await self._next_batch(res, remaining): break + if length is not None: + remaining = length - len(res) + if remaining == 0: + break return res diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 8421667bec..bae77bb304 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1260,16 +1260,20 @@ async def next(self) -> _DocumentType: else: raise StopAsyncIteration - async def _next_batch(self, result: list) -> bool: - """Get all available documents from the cursor.""" + async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + """Get all or some documents from the cursor.""" if not self._exhaust_checked: self._exhaust_checked = True await self._supports_exhaust() if self._empty: return False if len(self._data) or await self._refresh(): - result.extend(self._data) - self._data.clear() + if total is None: + result.extend(self._data) + self._data.clear() + else: + for _ in range(min(len(self._data), total)): + result.append(self._data.popleft()) return True else: return False @@ -1286,21 +1290,32 @@ async def __aenter__(self) -> AsyncCursor[_DocumentType]: async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.close() - async def to_list(self) -> list[_DocumentType]: + async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``. To use:: >>> await cursor.to_list() + Or, to read at most n items from the cursor:: + + >>> await cursor.to_list(n) + If the cursor is empty or has no more results, an empty list will be returned. .. 
versionadded:: 4.9 """ res: list[_DocumentType] = [] + remaining = length + if isinstance(length, int) and length < 1: + raise ValueError("to_list() length must be greater than 0") while self.alive: - if not await self._next_batch(res): + if not await self._next_batch(res, remaining): break + if length is not None: + remaining = length - len(res) + if remaining == 0: + break return res diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index 86fa69dcb6..da05bf1a3b 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -346,13 +346,17 @@ def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]: else: return None - def _next_batch(self, result: list) -> bool: - """Get all available documents from the cursor.""" + def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + """Get all or some available documents from the cursor.""" if not len(self._data) and not self._killed: self._refresh() if len(self._data): - result.extend(self._data) - self._data.clear() + if total is None: + result.extend(self._data) + self._data.clear() + else: + for _ in range(min(len(self._data), total)): + result.append(self._data.popleft()) return True else: return False @@ -381,21 +385,32 @@ def __enter__(self) -> CommandCursor[_DocumentType]: def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def to_list(self) -> list[_DocumentType]: + def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``. To use:: >>> cursor.to_list() + Or, to read at most n items from the cursor:: + + >>> cursor.to_list(n) + If the cursor is empty or has no more results, an empty list will be returned. .. 
versionadded:: 4.9 """ res: list[_DocumentType] = [] + remaining = length + if isinstance(length, int) and length < 1: + raise ValueError("to_list() length must be greater than 0") while self.alive: - if not self._next_batch(res): + if not self._next_batch(res, remaining): break + if length is not None: + remaining = length - len(res) + if remaining == 0: + break return res diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 1595ce40b9..c352b64098 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1258,16 +1258,20 @@ def next(self) -> _DocumentType: else: raise StopIteration - def _next_batch(self, result: list) -> bool: - """Get all available documents from the cursor.""" + def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + """Get all or some documents from the cursor.""" if not self._exhaust_checked: self._exhaust_checked = True self._supports_exhaust() if self._empty: return False if len(self._data) or self._refresh(): - result.extend(self._data) - self._data.clear() + if total is None: + result.extend(self._data) + self._data.clear() + else: + for _ in range(min(len(self._data), total)): + result.append(self._data.popleft()) return True else: return False @@ -1284,21 +1288,32 @@ def __enter__(self) -> Cursor[_DocumentType]: def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def to_list(self) -> list[_DocumentType]: + def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``. To use:: >>> cursor.to_list() + Or, to read at most n items from the cursor:: + + >>> cursor.to_list(n) + If the cursor is empty or has no more results, an empty list will be returned. .. 
versionadded:: 4.9 """ res: list[_DocumentType] = [] + remaining = length + if isinstance(length, int) and length < 1: + raise ValueError("to_list() length must be greater than 0") while self.alive: - if not self._next_batch(res): + if not self._next_batch(res, remaining): break + if length is not None: + remaining = length - len(res) + if remaining == 0: + break return res diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index d6d56244f7..0b6effc19b 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1401,6 +1401,20 @@ async def test_to_list_empty(self): docs = await c.to_list() self.assertEqual([], docs) + async def test_to_list_length(self): + coll = self.db.test + await coll.insert_many([{} for _ in range(5)]) + self.addCleanup(coll.drop) + c = coll.find() + docs = await c.to_list(3) + self.assertEqual(len(docs), 3) + + c = coll.find(batch_size=2) + docs = await c.to_list(3) + self.assertEqual(len(docs), 3) + docs = await c.to_list(3) + self.assertEqual(len(docs), 2) + @async_client_context.require_change_streams async def test_command_cursor_to_list(self): # Set maxAwaitTimeMS=1 to speed up the test. 
@@ -1417,6 +1431,19 @@ async def test_command_cursor_to_list_empty(self): docs = await c.to_list() self.assertEqual([], docs) + @async_client_context.require_change_streams + async def test_command_cursor_to_list_length(self): + db = self.db + await db.drop_collection("test") + await db.test.insert_many([{"foo": 1}, {"foo": 2}]) + + pipeline = {"$project": {"_id": False, "foo": True}} + result = await db.test.aggregate([pipeline]) + self.assertEqual(len(await result.to_list()), 2) + + result = await db.test.aggregate([pipeline]) + self.assertEqual(len(await result.to_list(1)), 1) + class TestRawBatchCursor(AsyncIntegrationTest): async def test_find_raw(self): diff --git a/test/test_cursor.py b/test/test_cursor.py index 0d61865196..5d5e17d128 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1392,6 +1392,20 @@ def test_to_list_empty(self): docs = c.to_list() self.assertEqual([], docs) + def test_to_list_length(self): + coll = self.db.test + coll.insert_many([{} for _ in range(5)]) + self.addCleanup(coll.drop) + c = coll.find() + docs = c.to_list(3) + self.assertEqual(len(docs), 3) + + c = coll.find(batch_size=2) + docs = c.to_list(3) + self.assertEqual(len(docs), 3) + docs = c.to_list(3) + self.assertEqual(len(docs), 2) + @client_context.require_change_streams def test_command_cursor_to_list(self): # Set maxAwaitTimeMS=1 to speed up the test. 
@@ -1408,6 +1422,19 @@ def test_command_cursor_to_list_empty(self): docs = c.to_list() self.assertEqual([], docs) + @client_context.require_change_streams + def test_command_cursor_to_list_length(self): + db = self.db + db.drop_collection("test") + db.test.insert_many([{"foo": 1}, {"foo": 2}]) + + pipeline = {"$project": {"_id": False, "foo": True}} + result = db.test.aggregate([pipeline]) + self.assertEqual(len(result.to_list()), 2) + + result = db.test.aggregate([pipeline]) + self.assertEqual(len(result.to_list(1)), 1) + class TestRawBatchCursor(IntegrationTest): def test_find_raw(self): diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 27b38dc0b0..19ec152bd1 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -440,6 +440,12 @@ def test_gridfs_find(self): gout = next(cursor) self.assertEqual(b"test2+", gout.read()) self.assertRaises(StopIteration, cursor.__next__) + cursor.rewind() + items = cursor.to_list() + self.assertEqual(len(items), 2) + cursor.rewind() + items = cursor.to_list(1) + self.assertEqual(len(items), 1) cursor.close() self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})