From 7c93c4a7caf46f4479a89c81458f2b79006e23f1 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Mon, 7 Aug 2023 14:51:46 +0200 Subject: [PATCH] Implement DoubleEndedIterator for PyListIterator by caching the length while still validating it before access. --- newsfragments/3366.added.md | 2 +- src/types/list.rs | 75 ++++++++++++++++++++++++++++++++----- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/newsfragments/3366.added.md b/newsfragments/3366.added.md index ffc4eff8186..9a6a0799e22 100644 --- a/newsfragments/3366.added.md +++ b/newsfragments/3366.added.md @@ -1 +1 @@ -Add implementation `DoubleEndedIterator` for `PyTupleIterator`. +Add implementations `DoubleEndedIterator` for `PyTupleIterator` and `PyListIterator`. diff --git a/src/types/list.rs b/src/types/list.rs index 8f3a1672f6d..b5bbf70fade 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -1,4 +1,5 @@ use std::convert::TryInto; +use std::iter::FusedIterator; use crate::err::{self, PyResult}; use crate::ffi::{self, Py_ssize_t}; @@ -264,6 +265,7 @@ impl PyList { PyListIterator { list: self, index: 0, + length: self.len(), } } @@ -291,18 +293,28 @@ index_impls!(PyList, "list", PyList::len, PyList::get_slice); pub struct PyListIterator<'a> { list: &'a PyList, index: usize, + length: usize, +} + +impl<'a> PyListIterator<'a> { + unsafe fn get_item(&self, index: usize) -> &'a PyAny { + #[cfg(any(Py_LIMITED_API, PyPy))] + let item = self.list.get_item(index).expect("list.get failed"); + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + let item = self.list.get_item_unchecked(index); + item + } } impl<'a> Iterator for PyListIterator<'a> { type Item = &'a PyAny; #[inline] - fn next(&mut self) -> Option<&'a PyAny> { - if self.index < self.list.len() { - #[cfg(any(Py_LIMITED_API, PyPy))] - let item = self.list.get_item(self.index).expect("list.get failed"); - #[cfg(not(any(Py_LIMITED_API, PyPy)))] - let item = unsafe { self.list.get_item_unchecked(self.index) }; + fn next(&mut self) -> Option { + let length = self.length.min(self.list.len()); + + if self.index < length { + let item = unsafe { self.get_item(self.index) }; self.index += 1; Some(item) } else { @@ -317,13 +329,30 @@ impl<'a> Iterator for PyListIterator<'a> { } } +impl<'a> DoubleEndedIterator for PyListIterator<'a> { + #[inline] + fn next_back(&mut self) -> Option { + let length = self.length.min(self.list.len()); + + if self.index < length { + let item = unsafe { self.get_item(length - 1) }; + self.length = length - 1; + Some(item) + } else { + None + } + } +} + impl<'a> ExactSizeIterator for PyListIterator<'a> { fn len(&self) -> usize { - self.list.len().saturating_sub(self.index) + self.length.saturating_sub(self.index) } } -impl<'a> std::iter::IntoIterator for &'a PyList { +impl FusedIterator for PyListIterator<'_> {} + +impl<'a> IntoIterator for &'a PyList { type Item = &'a PyAny; type IntoIter = PyListIterator<'a>; @@ -494,13 +523,41 @@ mod tests { iter.next(); assert_eq!(iter.size_hint(), (v.len() - 1, Some(v.len() - 1))); - // Exhust iterator. + // Exhaust iterator. for _ in &mut iter {} assert_eq!(iter.size_hint(), (0, Some(0))); }); } + #[test] + fn test_iter_rev() { + Python::with_gil(|py| { + let v = vec![2, 3, 5, 7]; + let ob = v.to_object(py); + let list: &PyList = ob.downcast(py).unwrap(); + + let mut iter = list.iter().rev(); + + assert_eq!(iter.size_hint(), (4, Some(4))); + + assert_eq!(iter.next().unwrap().extract::().unwrap(), 7); + assert_eq!(iter.size_hint(), (3, Some(3))); + + assert_eq!(iter.next().unwrap().extract::().unwrap(), 5); + assert_eq!(iter.size_hint(), (2, Some(2))); + + assert_eq!(iter.next().unwrap().extract::().unwrap(), 3); + assert_eq!(iter.size_hint(), (1, Some(1))); + + assert_eq!(iter.next().unwrap().extract::().unwrap(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); + + assert!(iter.next().is_none()); + assert!(iter.next().is_none()); + }); + } + #[test] fn test_into_iter() { Python::with_gil(|py| {