Skip to content

Commit d17dbd1

Browse files
committed
Add rpds's Stack.
1 parent 74707af commit d17dbd1

File tree

2 files changed

+302
-0
lines changed

2 files changed

+302
-0
lines changed

src/lib.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use pyo3::{exceptions::PyKeyError, types::PyMapping, types::PyTupleMethods};
55
use pyo3::{prelude::*, BoundObject, PyTypeInfo};
66
use rpds::{
77
HashTrieMap, HashTrieMapSync, HashTrieSet, HashTrieSetSync, List, ListSync, Queue, QueueSync,
8+
Stack, StackSync,
89
};
910
use std::collections::hash_map::DefaultHasher;
1011
use std::hash::{Hash, Hasher};
@@ -1238,6 +1239,155 @@ impl ListIterator {
12381239
}
12391240
}
12401241

1242+
#[repr(transparent)]
1243+
#[pyclass(name = "Stack", module = "rpds", frozen, sequence)]
1244+
struct StackPy {
1245+
inner: StackSync<Py<PyAny>>,
1246+
}
1247+
1248+
impl From<StackSync<Py<PyAny>>> for StackPy {
1249+
fn from(elements: StackSync<Py<PyAny>>) -> Self {
1250+
StackPy { inner: elements }
1251+
}
1252+
}
1253+
1254+
#[pymethods]
1255+
impl StackPy {
1256+
#[new]
1257+
#[pyo3(signature = (*args))]
1258+
fn init(args: &Bound<'_, PyTuple>) -> PyResult<Self> {
1259+
let mut inner = Stack::new_sync();
1260+
if args.len() == 1 {
1261+
for each in args.get_item(0)?.try_iter()? {
1262+
inner.push_mut(each?.extract()?);
1263+
}
1264+
} else {
1265+
for each in args {
1266+
inner.push_mut(each.extract()?);
1267+
}
1268+
}
1269+
Ok(StackPy { inner })
1270+
}
1271+
1272+
fn __hash__(&self, py: Python<'_>) -> PyResult<u64> {
1273+
let mut hasher = DefaultHasher::new();
1274+
1275+
self.inner
1276+
.iter()
1277+
.enumerate()
1278+
.try_for_each(|(index, each)| {
1279+
each.bind(py)
1280+
.hash()
1281+
.map_err(|_| {
1282+
PyTypeError::new_err(format!(
1283+
"Unhashable type at {} element in Stack: {}",
1284+
index,
1285+
each.bind(py)
1286+
.repr()
1287+
.and_then(|r| r.extract())
1288+
.unwrap_or("<repr> error".to_string())
1289+
))
1290+
})
1291+
.map(|x| hasher.write_isize(x))
1292+
})?;
1293+
1294+
Ok(hasher.finish())
1295+
}
1296+
1297+
fn __iter__(slf: PyRef<'_, Self>) -> StackIterator {
1298+
StackIterator {
1299+
inner: slf.inner.clone(),
1300+
}
1301+
}
1302+
1303+
fn __len__(&self) -> usize {
1304+
self.inner.size()
1305+
}
1306+
1307+
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<Py<PyAny>> {
1308+
match op {
1309+
CompareOp::Eq => (self.inner.size() == other.inner.size()
1310+
&& self
1311+
.inner
1312+
.iter()
1313+
.zip(other.inner.iter())
1314+
.map(|(e1, e2)| e1.bind(py).eq(e2))
1315+
.all(|r| r.unwrap_or(false)))
1316+
.into_pyobject(py)
1317+
.map_err(Into::into)
1318+
.map(BoundObject::into_any)
1319+
.map(BoundObject::unbind),
1320+
CompareOp::Ne => (self.inner.size() != other.inner.size()
1321+
|| self
1322+
.inner
1323+
.iter()
1324+
.zip(other.inner.iter())
1325+
.map(|(e1, e2)| e1.bind(py).ne(e2))
1326+
.any(|r| r.unwrap_or(true)))
1327+
.into_pyobject(py)
1328+
.map_err(Into::into)
1329+
.map(BoundObject::into_any)
1330+
.map(BoundObject::unbind),
1331+
_ => Ok(py.NotImplemented()),
1332+
}
1333+
}
1334+
1335+
fn __repr__(&self, py: Python) -> PyResult<String> {
1336+
let contents = self.inner.into_iter().map(|k| {
1337+
Ok(k.into_pyobject(py)?
1338+
.call_method0("__repr__")
1339+
.and_then(|r| r.extract())
1340+
.unwrap_or("<repr failed>".to_owned()))
1341+
});
1342+
let mut contents = contents.collect::<Result<Vec<_>, PyErr>>()?;
1343+
contents.reverse();
1344+
Ok(format!("Stack([{}])", contents.join(", ")))
1345+
}
1346+
1347+
fn peek(&self, py: Python) -> PyResult<Py<PyAny>> {
1348+
if let Some(peeked) = self.inner.peek() {
1349+
Ok(peeked.clone_ref(py))
1350+
} else {
1351+
Err(PyIndexError::new_err("peeked an empty stack"))
1352+
}
1353+
}
1354+
1355+
fn pop(&self) -> PyResult<StackPy> {
1356+
if let Some(popped) = self.inner.pop() {
1357+
Ok(StackPy { inner: popped })
1358+
} else {
1359+
Err(PyIndexError::new_err("popped an empty stack"))
1360+
}
1361+
}
1362+
1363+
fn push(&self, other: Py<PyAny>) -> StackPy {
1364+
StackPy {
1365+
inner: self.inner.push(other),
1366+
}
1367+
}
1368+
}
1369+
1370+
#[pyclass(module = "rpds")]
1371+
struct StackIterator {
1372+
inner: StackSync<Py<PyAny>>,
1373+
}
1374+
1375+
#[pymethods]
1376+
impl StackIterator {
1377+
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
1378+
slf
1379+
}
1380+
1381+
fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<Py<PyAny>> {
1382+
let first_op = slf.inner.peek()?;
1383+
let first = first_op.clone_ref(slf.py());
1384+
1385+
slf.inner = slf.inner.pop()?;
1386+
1387+
Some(first)
1388+
}
1389+
}
1390+
12411391
#[pyclass(module = "rpds")]
12421392
struct QueueIterator {
12431393
inner: QueueSync<Py<PyAny>>,
@@ -1401,6 +1551,7 @@ fn rpds_py(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
14011551
m.add_class::<HashTrieMapPy>()?;
14021552
m.add_class::<HashTrieSetPy>()?;
14031553
m.add_class::<ListPy>()?;
1554+
m.add_class::<StackPy>()?;
14041555
m.add_class::<QueuePy>()?;
14051556

14061557
PyMapping::register::<HashTrieMapPy>(py)?;

tests/test_stack.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""
2+
Modified from the pyrsistent test suite.
3+
4+
Pre-modification, these were MIT licensed, and are copyright:
5+
6+
Copyright (c) 2022 Tobias Gustafsson
7+
8+
Permission is hereby granted, free of charge, to any person
9+
obtaining a copy of this software and associated documentation
10+
files (the "Software"), to deal in the Software without
11+
restriction, including without limitation the rights to use,
12+
copy, modify, merge, publish, distribute, sublicense, and/or sell
13+
copies of the Software, and to permit persons to whom the
14+
Software is furnished to do so, subject to the following
15+
conditions:
16+
17+
The above copyright notice and this permission notice shall be
18+
included in all copies or substantial portions of the Software.
19+
20+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
22+
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
24+
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
25+
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
26+
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
27+
OTHER DEALINGS IN THE SOFTWARE.
28+
"""
29+
30+
import pytest
31+
32+
from rpds import Stack
33+
34+
35+
def test_literalish_works():
36+
assert Stack(1, 2, 3) == Stack([1, 2, 3])
37+
38+
39+
def test_pop_and_peek():
40+
ps = Stack([1, 2])
41+
assert ps.peek() == 2
42+
assert ps.pop().peek() == 1
43+
assert ps.pop().pop() == Stack()
44+
45+
46+
def test_instantiate_large_stack():
47+
assert Stack(range(1000)).peek() == 999
48+
49+
50+
def test_iteration():
51+
assert list(Stack()) == []
52+
assert list(Stack([1, 2, 3]))[::-1] == [1, 2, 3]
53+
54+
55+
def test_push():
56+
assert Stack([1, 2, 3]).push(4) == Stack([1, 2, 3, 4])
57+
58+
59+
def test_push_empty_stack():
60+
assert Stack().push(0) == Stack([0])
61+
62+
63+
def test_truthiness():
64+
assert Stack([1])
65+
assert not Stack()
66+
67+
68+
def test_len():
69+
assert len(Stack([1, 2, 3])) == 3
70+
assert len(Stack()) == 0
71+
72+
73+
def test_peek_illegal_on_empty_stack():
74+
with pytest.raises(IndexError):
75+
Stack().peek()
76+
77+
78+
def test_pop_illegal_on_empty_stack():
79+
with pytest.raises(IndexError):
80+
Stack().pop()
81+
82+
83+
def test_inequality():
84+
assert Stack([1, 2]) != Stack([1, 3])
85+
assert Stack([1, 2]) != Stack([1, 2, 3])
86+
assert Stack() != Stack([1, 2, 3])
87+
88+
89+
def test_repr():
90+
assert str(Stack()) == "Stack([])"
91+
assert str(Stack([1, 2, 3])) in "Stack([1, 2, 3])"
92+
93+
94+
def test_hashing():
95+
o = object()
96+
97+
assert hash(Stack([o, o])) == hash(Stack([o, o]))
98+
assert hash(Stack([o])) == hash(Stack([o]))
99+
assert hash(Stack()) == hash(Stack([]))
100+
assert not (hash(Stack([1, 2])) == hash(Stack([1, 3])))
101+
assert not (hash(Stack([1, 2])) == hash(Stack([2, 1])))
102+
assert not (hash(Stack([o])) == hash(Stack([o, o])))
103+
assert not (hash(Stack([])) == hash(Stack([o])))
104+
105+
assert hash(Stack([1, 2])) != hash(Stack([1, 3]))
106+
assert hash(Stack([1, 2])) != hash(Stack([2, 1]))
107+
assert hash(Stack([o])) != hash(Stack([o, o]))
108+
assert hash(Stack([])) != hash(Stack([o]))
109+
assert not (hash(Stack([o, o])) != hash(Stack([o, o])))
110+
assert not (hash(Stack([o])) != hash(Stack([o])))
111+
assert not (hash(Stack([])) != hash(Stack([])))
112+
113+
114+
def test_sequence():
115+
m = Stack("asdf")
116+
assert m == Stack(["a", "s", "d", "f"])
117+
118+
119+
# Non-pyrsistent-test-suite tests
120+
121+
122+
def test_more_eq():
123+
o = object()
124+
125+
assert Stack([o, o]) == Stack([o, o])
126+
assert Stack([o]) == Stack([o])
127+
assert Stack() == Stack([])
128+
assert not (Stack([1, 2]) == Stack([1, 3]))
129+
assert not (Stack([o]) == Stack([o, o]))
130+
assert not (Stack([]) == Stack([o]))
131+
132+
assert Stack([1, 2]) != Stack([1, 3])
133+
assert Stack([o]) != Stack([o, o])
134+
assert Stack([]) != Stack([o])
135+
assert not (Stack([o, o]) != Stack([o, o]))
136+
assert not (Stack([o]) != Stack([o]))
137+
assert not (Stack() != Stack([]))
138+
139+
140+
def test_rpds_doc():
141+
"""
142+
From the rpds docs.
143+
"""
144+
stack = Stack().push("stack")
145+
assert stack.peek() == "stack"
146+
147+
a_stack = stack.push("a")
148+
assert a_stack.peek() == "a"
149+
150+
stack_popped = a_stack.pop()
151+
assert stack_popped == stack

0 commit comments

Comments
 (0)