diff --git a/src/lib.rs b/src/lib.rs index 6046455..ad3cce2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use pyo3::exceptions::PyIndexError; @@ -1129,6 +1130,16 @@ impl QueuePy { .all(|r| r.unwrap_or(false)) } + fn __hash__(&self, py: Python<'_>) -> PyResult { + let hash = PyModule::import(py, "builtins")?.getattr("hash")?; + let mut hasher = DefaultHasher::new(); + for each in &self.inner { + let n: i64 = hash.call1((each.into_py(py),))?.extract()?; + hasher.write_i64(n); + } + Ok(hasher.finish()) + } + fn __ne__(&self, other: &Self, py: Python<'_>) -> bool { (self.inner.len() != other.inner.len()) || self diff --git a/tests/test_queue.py b/tests/test_queue.py index 0ba8f81..4dfde8f 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -89,12 +89,6 @@ def test_repr(): assert str(Queue([1, 2, 3])) in "Queue([1, 2, 3])" -@pytest.mark.xfail(reason=HASH_MSG) -def test_hashing(): - assert hash(Queue([1, 2])) == hash(Queue([1, 2])) - assert hash(Queue([1, 2])) != hash(Queue([2, 1])) - - def test_sequence(): m = Queue("asdf") assert m == Queue(["a", "s", "d", "f"]) @@ -131,3 +125,15 @@ def test_more_eq(): assert not (Queue([o, o]) != Queue([o, o])) assert not (Queue([o]) != Queue([o])) assert not (Queue() != Queue([])) + + +def test_hashing(): + assert hash(Queue([1, 2])) == hash(Queue([1, 2])) + assert hash(Queue([1, 2])) != hash(Queue([2, 1])) + assert len({Queue([1, 2]), Queue([1, 2])}) == 1 + + +def test_unhashable_contents(): + q = Queue([1, {1}]) + with pytest.raises(TypeError): + hash(q)