diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ebb355..6ea1439 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## Unreleased + + - Add `pyobject` module for serialize and deserialize `PyObject` + - Add `PyObjectVisitor` that can be used to implement deserialize + - Add feature `serde_with` + ## 0.21.1 - 2024-04-02 - Fix compile error when using PyO3 `abi3` feature targeting a minimum version below 3.10 diff --git a/Cargo.toml b/Cargo.toml index 79b8a91..71e4d14 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,10 +10,15 @@ homepage = "https://github.com/davidhewitt/pythonize" repository = "https://github.com/davidhewitt/pythonize" documentation = "https://docs.rs/crate/pythonize/" +[features] +default = ["serde_with"] +serde_with = ["serde-transcode"] [dependencies] serde = { version = "1.0", default-features = false, features = ["std"] } pyo3 = { version = "0.21.0", default-features = false } +serde-transcode = { version = "1.0", default-features = false, optional = true } + [dev-dependencies] serde = { version = "1.0", default-features = false, features = ["derive"] } diff --git a/src/lib.rs b/src/lib.rs index ea04c22..710cbcb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,6 +39,7 @@ mod de; mod error; mod ser; +mod visitor; #[allow(deprecated)] pub use crate::de::depythonize; @@ -48,3 +49,41 @@ pub use crate::ser::{ pythonize, pythonize_custom, PythonizeDefault, PythonizeDictType, PythonizeListType, PythonizeTypes, Pythonizer, }; +pub use crate::visitor::PyObjectVisitor; + +#[cfg(feature = "serde_with")] +/// This module provides a Serde `Serialize` and `Deserialize` implementation for `PyObject`. +/// +/// ```rust +/// use pyo3::PyObject; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// struct Foo { +/// #[serde(with = "pythonize::pyobject")] +/// #[serde(flatten)] +/// inner: PyObject, +/// } +/// ``` +pub mod pyobject { + use pyo3::{PyObject, Python}; + use serde::Serializer; + + pub fn serialize(obj: &PyObject, serializer: S) -> Result + where + S: Serializer, + { + Python::with_gil(|py| { + let mut deserializer = + crate::Depythonizer::from_object_bound(obj.clone().into_bound(py)); + serde_transcode::transcode(&mut deserializer, serializer) + }) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Python::with_gil(|py| deserializer.deserialize_any(crate::PyObjectVisitor::new(py))) + } +} diff --git a/src/visitor.rs b/src/visitor.rs new file mode 100644 index 0000000..8469bd9 --- /dev/null +++ b/src/visitor.rs @@ -0,0 +1,183 @@ +use pyo3::{ + types::{PyDict, PyDictMethods, PyList, PyListMethods}, + PyObject, Python, ToPyObject, +}; +use serde::{ + de::{MapAccess, SeqAccess, VariantAccess, Visitor}, + Deserialize, +}; + +pub struct PyObjectVisitor<'py> { + py: Python<'py>, +} + +impl<'py> PyObjectVisitor<'py> { + pub fn new(py: Python) -> PyObjectVisitor<'_> { + PyObjectVisitor { py } + } +} + +struct Wrapper { + inner: PyObject, +} + +impl<'de> Deserialize<'de> for Wrapper { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Python::with_gil(|py| { + deserializer + .deserialize_any(PyObjectVisitor { py }) + .map(|inner| Wrapper { inner }) + }) + } +} + +impl<'de, 'py> Visitor<'de> for PyObjectVisitor<'py> { + type Value = PyObject; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("any PyObject") + } + + fn visit_bool(self, value: bool) -> Result { + Ok(value.to_object(self.py)) + } + + fn visit_i64(self, value: i64) -> Result { + Ok(value.to_object(self.py)) + } + + fn visit_i128(self, value: i128) -> Result { + Ok(value.to_object(self.py)) + } + + fn visit_u64(self, value: u64) -> Result { + Ok(value.to_object(self.py)) + } + + fn visit_u128(self, value: u128) -> Result { + Ok(value.to_object(self.py)) + } + + fn visit_f64(self, value: f64) -> Result { + Ok(value.to_object(self.py)) + } + + fn visit_str(self, value: &str) -> Result { + Ok(value.to_object(self.py)) + } + + fn visit_string(self, value: String) -> Result { + Ok(value.to_object(self.py)) + } + + fn visit_bytes(self, v: &[u8]) -> Result { + Ok(v.to_object(self.py)) + } + + fn visit_none(self) -> Result { + Ok(self.py.None()) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let wrapper: Wrapper = Deserialize::deserialize(deserializer)?; + Ok(wrapper.inner) + } + + fn visit_unit(self) -> Result { + self.visit_none() + } + + fn visit_seq(self, mut visitor: V) -> Result + where + V: SeqAccess<'de>, + { + let list = PyList::empty_bound(self.py); + + while let Some(wrapper) = visitor.next_element::()? { + list.append(wrapper.inner).unwrap(); + } + + Ok(list.to_object(self.py)) + } + + fn visit_map(self, mut visitor: V) -> Result + where + V: MapAccess<'de>, + { + let dict = PyDict::new_bound(self.py); + + while let Some((key, value)) = visitor.next_entry::()? { + dict.set_item(key.inner, value.inner).unwrap(); + } + + Ok(dict.to_object(self.py)) + } + + fn visit_enum(self, data: A) -> Result + where + A: serde::de::EnumAccess<'de>, + { + let dict = PyDict::new_bound(self.py); + let (value, variant): (Wrapper, _) = data.variant()?; + let variant: Wrapper = variant.newtype_variant()?; + dict.set_item(variant.inner, value.inner).unwrap(); + + Ok(dict.to_object(self.py)) + } +} + +#[cfg(test)] +mod tests { + use pyo3::{ + types::{PyAnyMethods, PyStringMethods}, + PyObject, Python, + }; + use serde::{Deserialize, Serialize}; + use serde_json::json; + + #[derive(Serialize, Deserialize)] + struct Foo { + #[serde(with = "crate::pyobject")] + #[serde(flatten)] + inner: PyObject, + } + + #[test] + fn simple_test() { + let value = json!({ + "code": 200, + "success": true, + "payload": { + "features": [ + "serde", + "json" + ], + "homepage": null + } + }); + + Python::with_gil(|py| { + let foo = Foo { + inner: crate::pythonize(py, &value).unwrap(), + }; + let serialized = serde_json::to_string(&foo).unwrap(); + let deserialized: Foo = serde_json::from_str(&serialized).unwrap(); + assert_eq!( + deserialized + .inner + .bind(py) + .repr() + .unwrap() + .to_str() + .unwrap(), + foo.inner.bind(py).repr().unwrap().to_str().unwrap() + ); + }); + } +}