Skip to content

Commit

Permalink
Add IntoIterator for &Bound types (#3923)
Browse files Browse the repository at this point in the history
* Add IntoIterator for &Bound<'py, PyList>

* Add a test for Bound<'_, PyList>.into_iter

* Implement IntoIterator for more &Bound types

* Remove some explicit .iter() calls

* Implement IntoIterator for &Bound<'py, PyIterator>
  • Loading branch information
LilyFoote authored Mar 4, 2024
1 parent 4114dcb commit 811a3e5
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyo3-benches/benches/bench_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn iter_dict(b: &mut Bencher<'_>) {
let dict = (0..LEN as u64).map(|i| (i, i * 2)).into_py_dict_bound(py);
let mut sum = 0;
b.iter(|| {
for (k, _v) in dict.iter() {
for (k, _v) in &dict {
let i: u64 = k.extract().unwrap();
sum += i;
}
Expand Down
2 changes: 1 addition & 1 deletion pyo3-benches/benches/bench_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn iter_list(b: &mut Bencher<'_>) {
let list = PyList::new_bound(py, 0..LEN);
let mut sum = 0;
b.iter(|| {
for x in list.iter() {
for x in &list {
let i: u64 = x.extract().unwrap();
sum += i;
}
Expand Down
2 changes: 1 addition & 1 deletion pyo3-benches/benches/bench_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn iter_set(b: &mut Bencher<'_>) {
let set = PySet::new_bound(py, &(0..LEN).collect::<Vec<_>>()).unwrap();
let mut sum = 0;
b.iter(|| {
for x in set.iter() {
for x in &set {
let i: u64 = x.extract().unwrap();
sum += i;
}
Expand Down
2 changes: 1 addition & 1 deletion src/conversions/hashbrown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ where
fn extract_bound(ob: &Bound<'py, PyAny>) -> Result<Self, PyErr> {
let dict = ob.downcast::<PyDict>()?;
let mut ret = hashbrown::HashMap::with_capacity_and_hasher(dict.len(), S::default());
for (k, v) in dict.iter() {
for (k, v) in dict {
ret.insert(k.extract()?, v.extract()?);
}
Ok(ret)
Expand Down
2 changes: 1 addition & 1 deletion src/conversions/indexmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ where
fn extract_bound(ob: &Bound<'py, PyAny>) -> Result<Self, PyErr> {
let dict = ob.downcast::<PyDict>()?;
let mut ret = indexmap::IndexMap::with_capacity_and_hasher(dict.len(), S::default());
for (k, v) in dict.iter() {
for (k, v) in dict {
ret.insert(k.extract()?, v.extract()?);
}
Ok(ret)
Expand Down
4 changes: 2 additions & 2 deletions src/conversions/std/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ where
fn extract_bound(ob: &Bound<'py, PyAny>) -> Result<Self, PyErr> {
let dict = ob.downcast::<PyDict>()?;
let mut ret = collections::HashMap::with_capacity_and_hasher(dict.len(), S::default());
for (k, v) in dict.iter() {
for (k, v) in dict {
ret.insert(k.extract()?, v.extract()?);
}
Ok(ret)
Expand All @@ -96,7 +96,7 @@ where
fn extract_bound(ob: &Bound<'py, PyAny>) -> Result<Self, PyErr> {
let dict = ob.downcast::<PyDict>()?;
let mut ret = collections::BTreeMap::new();
for (k, v) in dict.iter() {
for (k, v) in dict {
ret.insert(k.extract()?, v.extract()?);
}
Ok(ret)
Expand Down
29 changes: 29 additions & 0 deletions src/types/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,15 @@ impl<'py> IntoIterator for Bound<'py, PyDict> {
}
}

impl<'py> IntoIterator for &Bound<'py, PyDict> {
type Item = (Bound<'py, PyAny>, Bound<'py, PyAny>);
type IntoIter = BoundDictIterator<'py>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

mod borrowed_iter {
use super::*;

Expand Down Expand Up @@ -1123,6 +1132,26 @@ mod tests {
});
}

#[test]
fn test_iter_bound() {
Python::with_gil(|py| {
let mut v = HashMap::new();
v.insert(7, 32);
v.insert(8, 42);
v.insert(9, 123);
let ob = v.to_object(py);
let dict: &Bound<'_, PyDict> = ob.downcast_bound(py).unwrap();
let mut key_sum = 0;
let mut value_sum = 0;
for (key, value) in dict {
key_sum += key.extract::<i32>().unwrap();
value_sum += value.extract::<i32>().unwrap();
}
assert_eq!(7 + 8 + 9, key_sum);
assert_eq!(32 + 42 + 123, value_sum);
});
}

#[test]
fn test_iter_value_mutated() {
Python::with_gil(|py| {
Expand Down
21 changes: 21 additions & 0 deletions src/types/frozenset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ impl<'py> IntoIterator for Bound<'py, PyFrozenSet> {
}
}

impl<'py> IntoIterator for &Bound<'py, PyFrozenSet> {
type Item = Bound<'py, PyAny>;
type IntoIter = BoundFrozenSetIterator<'py>;

/// Returns an iterator of values in this set.
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

/// PyO3 implementation of an iterator for a Python `frozenset` object.
pub struct BoundFrozenSetIterator<'p> {
it: Bound<'p, PyIterator>,
Expand Down Expand Up @@ -357,6 +367,17 @@ mod tests {
});
}

#[test]
fn test_frozenset_iter_bound() {
Python::with_gil(|py| {
let set = PyFrozenSet::new_bound(py, &[1]).unwrap();

for el in &set {
assert_eq!(1i32, el.extract::<i32>().unwrap());
}
});
}

#[test]
fn test_frozenset_iter_size_hint() {
Python::with_gil(|py| {
Expand Down
42 changes: 42 additions & 0 deletions src/types/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ impl<'py> Borrowed<'_, 'py, PyIterator> {
}
}

impl<'py> IntoIterator for &Bound<'py, PyIterator> {
type Item = PyResult<Bound<'py, PyAny>>;
type IntoIter = Bound<'py, PyIterator>;

fn into_iter(self) -> Self::IntoIter {
self.clone()
}
}

impl PyTypeCheck for PyIterator {
const NAME: &'static str = "Iterator";

Expand Down Expand Up @@ -246,6 +255,39 @@ def fibonacci(target):
});
}

#[test]
fn fibonacci_generator_bound() {
use crate::types::any::PyAnyMethods;
use crate::Bound;

let fibonacci_generator = r#"
def fibonacci(target):
a = 1
b = 1
for _ in range(target):
yield a
a, b = b, a + b
"#;

Python::with_gil(|py| {
let context = PyDict::new_bound(py);
py.run_bound(fibonacci_generator, None, Some(&context))
.unwrap();

let generator: Bound<'_, PyIterator> = py
.eval_bound("fibonacci(5)", None, Some(&context))
.unwrap()
.downcast_into()
.unwrap();
let mut items = vec![];
for actual in &generator {
let actual = actual.unwrap().extract::<usize>().unwrap();
items.push(actual);
}
assert_eq!(items, [1, 1, 2, 3, 5]);
});
}

#[test]
fn int_not_iterable() {
Python::with_gil(|py| {
Expand Down
23 changes: 23 additions & 0 deletions src/types/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,15 @@ impl<'py> IntoIterator for Bound<'py, PyList> {
}
}

impl<'py> IntoIterator for &Bound<'py, PyList> {
type Item = Bound<'py, PyAny>;
type IntoIter = BoundListIterator<'py>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

#[cfg(test)]
#[cfg_attr(not(feature = "gil-refs"), allow(deprecated))]
mod tests {
Expand Down Expand Up @@ -911,6 +920,20 @@ mod tests {
});
}

#[test]
fn test_into_iter_bound() {
use crate::types::any::PyAnyMethods;

Python::with_gil(|py| {
let list = PyList::new_bound(py, [1, 2, 3, 4]);
let mut items = vec![];
for item in &list {
items.push(item.extract::<i32>().unwrap());
}
assert_eq!(items, vec![1, 2, 3, 4]);
});
}

#[test]
fn test_extract() {
Python::with_gil(|py| {
Expand Down
2 changes: 1 addition & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub use self::typeobject::{PyType, PyTypeMethods};
/// Python::with_gil(|py| {
/// let dict = py.eval_bound("{'a':'b', 'c':'d'}", None, None)?.downcast_into::<PyDict>()?;
///
/// for (key, value) in dict.iter() {
/// for (key, value) in &dict {
/// println!("key: {}, value: {}", key, value);
/// }
///
Expand Down
27 changes: 27 additions & 0 deletions src/types/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,20 @@ impl<'py> IntoIterator for Bound<'py, PySet> {
}
}

impl<'py> IntoIterator for &Bound<'py, PySet> {
type Item = Bound<'py, PyAny>;
type IntoIter = BoundSetIterator<'py>;

/// Returns an iterator of values in this set.
///
/// # Panics
///
/// If PyO3 detects that the set is mutated during iteration, it will panic.
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

/// PyO3 implementation of an iterator for a Python `set` object.
pub struct BoundSetIterator<'p> {
it: Bound<'p, PyIterator>,
Expand Down Expand Up @@ -482,6 +496,19 @@ mod tests {
});
}

#[test]
fn test_set_iter_bound() {
use crate::types::any::PyAnyMethods;

Python::with_gil(|py| {
let set = PySet::new_bound(py, &[1]).unwrap();

for el in &set {
assert_eq!(1i32, el.extract::<i32>().unwrap());
}
});
}

#[test]
#[should_panic]
fn test_set_iter_mutation() {
Expand Down
26 changes: 26 additions & 0 deletions src/types/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,15 @@ impl<'py> IntoIterator for Bound<'py, PyTuple> {
}
}

impl<'py> IntoIterator for &Bound<'py, PyTuple> {
type Item = Bound<'py, PyAny>;
type IntoIter = BoundTupleIterator<'py>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

/// Used by `PyTuple::iter_borrowed()`.
pub struct BorrowedTupleIterator<'a, 'py> {
tuple: Borrowed<'a, 'py, PyTuple>,
Expand Down Expand Up @@ -976,6 +985,23 @@ mod tests {
});
}

#[test]
fn test_into_iter_bound() {
use crate::Bound;

Python::with_gil(|py| {
let ob = (1, 2, 3).to_object(py);
let tuple: &Bound<'_, PyTuple> = ob.downcast_bound(py).unwrap();
assert_eq!(3, tuple.len());

let mut items = vec![];
for item in tuple {
items.push(item.extract::<usize>().unwrap());
}
assert_eq!(items, vec![1, 2, 3]);
});
}

#[test]
#[cfg(not(Py_LIMITED_API))]
fn test_as_slice() {
Expand Down

0 comments on commit 811a3e5

Please sign in to comment.