From e3ec052ecc6801449550e712962b8d927e3e880f Mon Sep 17 00:00:00 2001 From: Sebastian Puetz Date: Fri, 25 Oct 2019 12:12:05 +0200 Subject: [PATCH] Remove contains and iter from PyMappingProtocol. The methods are not expected by CPython and are only explicitly callable. To get iteration support, PyIterProtocol should be implemented and to get support for `x in mapping`, PySequenceProtocol's __contains__ should be implemented. https://github.com/PyO3/pyo3/issues/611 --- src/class/mapping.rs | 58 ------------------- tests/test_mapping.rs | 130 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 58 deletions(-) create mode 100644 tests/test_mapping.rs diff --git a/src/class/mapping.rs b/src/class/mapping.rs index 32326ff465b..f1c4b3ffb0b 100644 --- a/src/class/mapping.rs +++ b/src/class/mapping.rs @@ -43,20 +43,6 @@ pub trait PyMappingProtocol<'p>: PyTypeInfo { unimplemented!() } - fn __iter__(&'p self, py: Python<'p>) -> Self::Result - where - Self: PyMappingIterProtocol<'p>, - { - unimplemented!() - } - - fn __contains__(&'p self, value: Self::Value) -> Self::Result - where - Self: PyMappingContainsProtocol<'p>, - { - unimplemented!() - } - fn __reversed__(&'p self) -> Self::Result where Self: PyMappingReversedProtocol<'p>, @@ -89,16 +75,6 @@ pub trait PyMappingDelItemProtocol<'p>: PyMappingProtocol<'p> { type Result: Into>; } -pub trait PyMappingIterProtocol<'p>: PyMappingProtocol<'p> { - type Success: IntoPy; - type Result: Into>; -} - -pub trait PyMappingContainsProtocol<'p>: PyMappingProtocol<'p> { - type Value: FromPyObject<'p>; - type Result: Into>; -} - pub trait PyMappingReversedProtocol<'p>: PyMappingProtocol<'p> { type Success: IntoPy; type Result: Into>; @@ -142,12 +118,6 @@ where fn methods() -> Vec { let mut methods = Vec::new(); - if let Some(def) = ::__iter__() { - methods.push(def) - } - if let Some(def) = ::__contains__() { - methods.push(def) - } if let Some(def) = ::__reversed__() { methods.push(def) } @@ -283,20 +253,6 @@ where } } -#[doc(hidden)] -pub trait PyMappingContainsProtocolImpl { - fn __contains__() -> Option; -} - -impl<'p, T> PyMappingContainsProtocolImpl for T -where - T: PyMappingProtocol<'p>, -{ - default fn __contains__() -> Option { - None - } -} - #[doc(hidden)] pub trait PyMappingReversedProtocolImpl { fn __reversed__() -> Option; @@ -310,17 +266,3 @@ where None } } - -#[doc(hidden)] -pub trait PyMappingIterProtocolImpl { - fn __iter__() -> Option; -} - -impl<'p, T> PyMappingIterProtocolImpl for T -where - T: PyMappingProtocol<'p>, -{ - default fn __iter__() -> Option { - None - } -} diff --git a/tests/test_mapping.rs b/tests/test_mapping.rs new file mode 100644 index 00000000000..6160da0fc66 --- /dev/null +++ b/tests/test_mapping.rs @@ -0,0 +1,130 @@ +#![feature(specialization)] +use std::collections::HashMap; + +use pyo3::exceptions::KeyError; +use pyo3::prelude::*; +use pyo3::types::IntoPyDict; +use pyo3::types::PyList; +use pyo3::PyMappingProtocol; + +#[pyclass] +struct Mapping { + index: HashMap, +} + +#[pymethods] +impl Mapping { + #[new] + fn new(obj: &PyRawObject, elements: Option<&PyList>) -> PyResult<()> { + if let Some(pylist) = elements { + let mut elems = HashMap::with_capacity(pylist.len()); + for (i, pyelem) in pylist.into_iter().enumerate() { + let elem = String::extract(pyelem)?; + elems.insert(elem, i); + } + obj.init(Self { index: elems }); + } else { + obj.init(Self { + index: HashMap::new(), + }); + } + Ok(()) + } +} + +#[pyproto] +impl PyMappingProtocol for Mapping { + fn __len__(&self) -> PyResult { + Ok(self.index.len()) + } + + fn __getitem__(&self, query: String) -> PyResult { + self.index + .get(&query) + .copied() + .ok_or_else(|| KeyError::py_err("unknown key")) + } + + fn __setitem__(&mut self, key: String, value: usize) -> PyResult<()> { + self.index.insert(key, value); + Ok(()) + } + + fn __delitem__(&mut self, key: String) -> PyResult<()> { + if self.index.remove(&key).is_none() { + KeyError::py_err("unknown key").into() + } else { + Ok(()) + } + } + + /// not an actual reversed implementation, just to demonstrate that the method is callable. + fn __reversed__(&self) -> PyResult { + let gil = Python::acquire_gil(); + Ok(self + .index + .keys() + .cloned() + .collect::>() + .into_py(gil.python())) + } +} + +#[test] +fn test_getitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let d = [("Mapping", py.get_type::())].into_py_dict(py); + + let run = |code| py.run(code, None, Some(d)).unwrap(); + let err = |code| py.run(code, None, Some(d)).unwrap_err(); + + run("m = Mapping(['1', '2', '3']); assert m['1'] == 0"); + run("m = Mapping(['1', '2', '3']); assert m['2'] == 1"); + run("m = Mapping(['1', '2', '3']); assert m['3'] == 2"); + err("m = Mapping(['1', '2', '3']); print(m['4'])"); +} + +#[test] +fn test_setitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let d = [("Mapping", py.get_type::())].into_py_dict(py); + + let run = |code| py.run(code, None, Some(d)).unwrap(); + let err = |code| py.run(code, None, Some(d)).unwrap_err(); + + run("m = Mapping(['1', '2', '3']); m['1'] = 4; assert m['1'] == 4"); + run("m = Mapping(['1', '2', '3']); m['0'] = 0; assert m['0'] == 0"); + run("m = Mapping(['1', '2', '3']); len(m) == 4"); + err("m = Mapping(['1', '2', '3']); m[0] = 'hello'"); + err("m = Mapping(['1', '2', '3']); m[0] = -1"); +} + +#[test] +fn test_delitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let d = [("Mapping", py.get_type::())].into_py_dict(py); + let run = |code| py.run(code, None, Some(d)).unwrap(); + let err = |code| py.run(code, None, Some(d)).unwrap_err(); + + run( + "m = Mapping(['1', '2', '3']); del m['1']; assert len(m) == 2; \ + assert m['2'] == 1; assert m['3'] == 2", + ); + err("m = Mapping(['1', '2', '3']); del m[-1]"); + err("m = Mapping(['1', '2', '3']); del m['4']"); +} + +#[test] +fn test_reversed() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let d = [("Mapping", py.get_type::())].into_py_dict(py); + let run = |code| py.run(code, None, Some(d)).unwrap(); + + run("m = Mapping(['1', '2']); assert set(reversed(m)) == {'1', '2'}"); +}