Skip to content

Commit

Permalink
Update pyo3 (#125)
Browse files Browse the repository at this point in the history
* Update pyo3

* fix `eq` test
  • Loading branch information
flying-sheep authored Aug 28, 2024
1 parent 01b3325 commit e141d27
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 129 deletions.
118 changes: 44 additions & 74 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ nom = '7.1.3'
pest = '2.5.4'
pest_derive = '2.5.4'
thiserror = { version = '1.0.38', optional = true }
pyo3 = { version = '0.18.2', optional = true, features = ['multiple-pymethods']}
pyo3 = { version = '0.22.2', optional = true, features = ['multiple-pymethods'] }
paste = "1.0.12"

[build-dependencies]
Expand Down
25 changes: 0 additions & 25 deletions src/impl_help.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,3 @@
#[macro_export]
macro_rules! impl_richcmp_eq {
($cls: path) => {
#[cfg(feature = "pyo3")]
#[pyo3::pymethods]
impl $cls {
fn __richcmp__(
&self,
other: &Self,
op: pyo3::pyclass::CompareOp,
py: pyo3::Python,
) -> pyo3::PyObject {
use pyo3::pyclass::CompareOp::*;
use pyo3::IntoPy;

return match op {
Eq => (self == other).into_py(py),
Ne => (self != other).into_py(py),
_ => py.NotImplemented(),
};
}
}
};
}

#[macro_export]
macro_rules! impl_bitflags_accessors {
($cls: path, $( $flag: ident ),+ $(,)?) => {
Expand Down
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ pub static ATTR_NAMES: [&str; 6] = [
#[cfg(feature = "pyo3")]
#[pyo3::pymodule]
#[pyo3(name = "xdot_rs")]
pub fn pymodule(py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
pub fn pymodule(m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> pyo3::PyResult<()> {
use pyo3::prelude::*;

m.add_class::<ShapeDraw>()?;
m.add_function(pyo3::wrap_pyfunction!(parse_py, m)?)?;
let m_dict = py.import("sys")?.getattr("modules")?;
let m_dict = m.py().import_bound("sys")?.getattr("modules")?;
m.add_wrapped(pyo3::wrap_pymodule!(shapes::pymodule))?;
m_dict.set_item("xdot_rs.shapes", m.getattr("shapes")?)?;
m.add_wrapped(pyo3::wrap_pymodule!(draw::pymodule))?;
Expand Down
23 changes: 14 additions & 9 deletions src/xdot_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ pub use self::draw::Pen;
use self::shapes::Shape;

#[cfg(feature = "pyo3")]
fn try_into_shape(shape: &pyo3::PyAny) -> pyo3::PyResult<Shape> {
fn try_into_shape(shape: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Shape> {
use pyo3::prelude::*;

if let Ok(ell) = shape.extract::<shapes::Ellipse>() {
Ok(ell.into())
} else if let Ok(points) = shape.extract::<shapes::Points>() {
Expand All @@ -28,7 +30,7 @@ fn try_into_shape(shape: &pyo3::PyAny) -> pyo3::PyResult<Shape> {

/// A [Shape] together with a [Pen].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass(module = "xdot_rs"))]
#[cfg_attr(feature = "pyo3", pyo3::pyclass(eq, module = "xdot_rs"))]
pub struct ShapeDraw {
// #[pyo3(get, set)] not possible with cfg_attr
pub pen: Pen,
Expand All @@ -38,7 +40,7 @@ pub struct ShapeDraw {
#[pyo3::pymethods]
impl ShapeDraw {
#[new]
fn new(shape: &pyo3::PyAny, pen: Pen) -> pyo3::PyResult<Self> {
fn new(shape: &pyo3::Bound<'_, pyo3::PyAny>, pen: Pen) -> pyo3::PyResult<Self> {
let shape = try_into_shape(shape)?;
Ok(ShapeDraw { shape, pen })
}
Expand All @@ -60,19 +62,17 @@ impl ShapeDraw {
}
}
#[setter]
fn set_shape(&mut self, shape: &pyo3::PyAny) -> pyo3::PyResult<()> {
fn set_shape(&mut self, shape: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<()> {
self.shape = try_into_shape(shape)?;
Ok(())
}
}
impl_richcmp_eq!(ShapeDraw);

#[cfg(feature = "pyo3")]
#[test]
fn cmp_equal() {
use super::*;
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;

pyo3::prepare_freethreaded_python();

Expand All @@ -84,9 +84,14 @@ fn cmp_equal() {
filled: true,
};
Python::with_gil(|py| {
let a = ShapeDraw::new(ellip.clone().into_py(py).as_ref(py), Pen::default())?;
let b = ShapeDraw::new(ellip.clone().into_py(py).as_ref(py), Pen::default())?;
assert!(a.__richcmp__(&b, CompareOp::Eq, py).extract::<bool>(py)?);
let a = ShapeDraw::new(ellip.clone().into_py(py).bind(py), Pen::default())?;
let b = ShapeDraw::new(ellip.clone().into_py(py).bind(py), Pen::default())?;
assert!(a
.into_py(py)
.bind(py)
.getattr("__eq__")?
.call1((b,))?
.extract::<bool>()?);
Ok::<(), PyErr>(())
})
.unwrap();
Expand Down
Loading

0 comments on commit e141d27

Please sign in to comment.