diff --git a/src/types/dict.rs b/src/types/dict.rs index 1c8550e4344..732f4e4d8a3 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -642,7 +642,9 @@ where #[cfg(test)] mod tests { use super::*; - use crate::types::PyTuple; + use crate::tests::common::generate_unique_module_name; + use crate::types::{PyModule, PyTuple}; + use pyo3_ffi::c_str; use std::collections::{BTreeMap, HashMap}; #[test] @@ -965,6 +967,45 @@ mod tests { }); } + #[test] + fn test_iter_subclass() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let dict = ob.downcast_bound::(py).unwrap(); + + let module = PyModule::from_code( + py, + c_str!("class DictSubclass(dict): pass"), + c_str!("test.py"), + &generate_unique_module_name("test"), + ) + .unwrap(); + + let subclass = module + .getattr("DictSubclass") + .unwrap() + .call1((dict,)) + .unwrap() + .downcast_into::() + .unwrap(); + + let mut key_sum = 0; + let mut value_sum = 0; + let iter = subclass.iter(); + assert!(matches!(iter, BoundDictIterator::ItemIter { .. })); + for (key, value) in iter { + key_sum += key.extract::().unwrap(); + value_sum += value.extract::().unwrap(); + } + assert_eq!(7 + 8 + 9, key_sum); + assert_eq!(32 + 42 + 123, value_sum); + }); + } + #[test] fn test_iter_bound() { Python::with_gil(|py| {