Skip to content

Commit

Permalink
Implement Iterator for Bound<'py, PySequence>
Browse files Browse the repository at this point in the history
This allows using a `Bound<'py, PySequence>` directly in a for loop.
  • Loading branch information
LilyFoote committed Mar 2, 2024
1 parent 1c5265e commit cfed22f
Showing 1 changed file with 118 additions and 2 deletions.
120 changes: 118 additions & 2 deletions src/types/sequence.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::iter::FusedIterator;

use crate::err::{self, DowncastError, PyDowncastError, PyErr, PyResult};
use crate::exceptions::PyTypeError;
use crate::ffi_ptr_ext::FfiPtrExt;
Expand Down Expand Up @@ -287,6 +289,9 @@ pub trait PySequenceMethods<'py>: crate::sealed::Sealed {

/// Returns a fresh tuple based on the Sequence.
fn to_tuple(&self) -> PyResult<Bound<'py, PyTuple>>;

/// Returns an iterator over the Sequence's items.
fn iter(&self) -> BoundSequenceIterator<'py>;
}

impl<'py> PySequenceMethods<'py> for Bound<'py, PySequence> {
Expand Down Expand Up @@ -462,6 +467,100 @@ impl<'py> PySequenceMethods<'py> for Bound<'py, PySequence> {
.downcast_into_unchecked()
}
}

#[inline]
fn iter(&self) -> BoundSequenceIterator<'py> {
BoundSequenceIterator::new(self.clone())
}
}

pub struct BoundSequenceIterator<'py> {
sequence: Bound<'py, PySequence>,
index: usize,
length: usize,
}

impl<'py> BoundSequenceIterator<'py> {
fn new(sequence: Bound<'py, PySequence>) -> Self {
let length: usize = sequence.len().expect("failed to get sequence length");
Self {
sequence,
index: 0,
length,
}
}

unsafe fn get_item(&self, index: usize) -> Bound<'py, PyAny> {
self.sequence.get_item(index).expect("sequence.get failed")
}
}

impl<'py> Iterator for BoundSequenceIterator<'py> {
type Item = Bound<'py, PyAny>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
let length = self
.length
.min(self.sequence.len().expect("failed to get sequence length"));

if self.index < length {
let item = unsafe { self.get_item(self.index) };
self.index += 1;
Some(item)
} else {
None
}
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len();
(len, Some(len))
}
}

impl DoubleEndedIterator for BoundSequenceIterator<'_> {
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
let length = self
.length
.min(self.sequence.len().expect("failed to get sequence length"));

if self.index < length {
let item = unsafe { self.get_item(length - 1) };
self.length = length - 1;
Some(item)
} else {
None
}
}
}

impl ExactSizeIterator for BoundSequenceIterator<'_> {
fn len(&self) -> usize {
self.length.saturating_sub(self.index)
}
}

impl FusedIterator for BoundSequenceIterator<'_> {}

impl<'py> IntoIterator for Bound<'py, PySequence> {
type Item = Bound<'py, PyAny>;
type IntoIter = BoundSequenceIterator<'py>;

fn into_iter(self) -> Self::IntoIter {
BoundSequenceIterator::new(self)
}
}

impl<'py> IntoIterator for &Bound<'py, PySequence> {
type Item = Bound<'py, PyAny>;
type IntoIter = BoundSequenceIterator<'py>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

#[inline]
Expand Down Expand Up @@ -509,8 +608,8 @@ where
};

let mut v = Vec::with_capacity(seq.len().unwrap_or(0));
for item in seq.iter()? {
v.push(item?.extract::<T>()?);
for item in seq.iter() {
v.push(item.extract::<T>()?);
}
Ok(v)
}
Expand Down Expand Up @@ -900,6 +999,23 @@ mod tests {
});
}

#[test]
fn test_seq_iter_bound() {
use crate::types::any::PyAnyMethods;

Python::with_gil(|py| {
let v: Vec<i32> = vec![1, 1, 2, 3, 5, 8];
let ob = v.to_object(py);
let seq = ob.downcast_bound::<PySequence>(py).unwrap();
let mut idx = 0;
for el in seq {
assert_eq!(v[idx], el.extract::<i32>().unwrap());
idx += 1;
}
assert_eq!(idx, v.len());
});
}

#[test]
fn test_seq_strings() {
Python::with_gil(|py| {
Expand Down

0 comments on commit cfed22f

Please sign in to comment.