Skip to content

Commit

Permalink
Fix serializing complex values in enums (#1524)
Browse files Browse the repository at this point in the history
  • Loading branch information
changhc authored Nov 4, 2024
1 parent 184e7be commit 2ee8fa8
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 41 deletions.
28 changes: 13 additions & 15 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,9 @@ pub(crate) fn infer_to_python_known(
PyList::new_bound(py, items).into_py(py)
}
ObType::Complex => {
let dict = value.downcast::<PyDict>()?;
let new_dict = PyDict::new_bound(py);
let _ = new_dict.set_item("real", dict.get_item("real")?);
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
new_dict.into_py(py)
let v = value.downcast::<PyComplex>()?;
let complex_str = type_serializers::complex::complex_to_str(v);
complex_str.into_py(py)
}
ObType::Path => value.str()?.into_py(py),
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
Expand Down Expand Up @@ -286,11 +284,9 @@ pub(crate) fn infer_to_python_known(
iter.into_py(py)
}
ObType::Complex => {
let dict = value.downcast::<PyDict>()?;
let new_dict = PyDict::new_bound(py);
let _ = new_dict.set_item("real", dict.get_item("real")?);
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
new_dict.into_py(py)
let v = value.downcast::<PyComplex>()?;
let complex_str = type_serializers::complex::complex_to_str(v);
complex_str.into_py(py)
}
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
Expand Down Expand Up @@ -422,10 +418,8 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
ObType::Bool => serialize!(bool),
ObType::Complex => {
let v = value.downcast::<PyComplex>().map_err(py_err_se_err)?;
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry(&"real", &v.real())?;
map.serialize_entry(&"imag", &v.imag())?;
map.end()
let complex_str = type_serializers::complex::complex_to_str(v);
Ok(serializer.collect_str::<String>(&complex_str)?)
}
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>().map_err(py_err_se_err)?;
Expand Down Expand Up @@ -672,7 +666,7 @@ pub(crate) fn infer_json_key_known<'a>(
}
Ok(Cow::Owned(key_build.finish()))
}
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => {
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
}
ObType::Dataclass | ObType::PydanticSerializable => {
Expand All @@ -689,6 +683,10 @@ pub(crate) fn infer_json_key_known<'a>(
// FIXME it would be nice to have a "PyCow" which carries ownership of the Python type too
Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned()))
}
ObType::Complex => {
let v = key.downcast::<PyComplex>()?;
Ok(type_serializers::complex::complex_to_str(v).into())
}
ObType::Pattern => Ok(Cow::Owned(
key.getattr(intern!(key.py(), "pattern"))?
.str()?
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/ob_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ impl ObTypeLookup {
ObType::Url
} else if ob_type == self.multi_host_url {
ObType::MultiHostUrl
} else if ob_type == self.complex {
ObType::Complex
} else if ob_type == self.uuid_object.as_ptr() as usize {
ObType::Uuid
} else if is_pydantic_serializable(op_value) {
Expand Down
45 changes: 19 additions & 26 deletions src/serializers/type_serializers/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,10 @@ impl TypeSerializer for ComplexSerializer {
) -> PyResult<PyObject> {
let py = value.py();
match value.downcast::<PyComplex>() {
Ok(py_complex) => match extra.mode {
SerMode::Json => {
let re = py_complex.real();
let im = py_complex.imag();
let mut s = format!("{im}j");
if re != 0.0 {
let mut sign = "";
if im >= 0.0 {
sign = "+";
}
s = format!("{re}{sign}{s}");
}
Ok(s.into_py(py))
}
_ => Ok(value.into_py(py)),
},
Ok(py_complex) => Ok(match extra.mode {
SerMode::Json => complex_to_str(py_complex).into_py(py),
_ => value.into_py(py),
}),
Err(_) => {
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
infer_to_python(value, include, exclude, extra)
Expand All @@ -70,16 +58,7 @@ impl TypeSerializer for ComplexSerializer {
) -> Result<S::Ok, S::Error> {
match value.downcast::<PyComplex>() {
Ok(py_complex) => {
let re = py_complex.real();
let im = py_complex.imag();
let mut s = format!("{im}j");
if re != 0.0 {
let mut sign = "";
if im >= 0.0 {
sign = "+";
}
s = format!("{re}{sign}{s}");
}
let s = complex_to_str(py_complex);
Ok(serializer.collect_str::<String>(&s)?)
}
Err(_) => {
Expand All @@ -93,3 +72,17 @@ impl TypeSerializer for ComplexSerializer {
"complex"
}
}

pub fn complex_to_str(py_complex: &Bound<'_, PyComplex>) -> String {
let re = py_complex.real();
let im = py_complex.imag();
let mut s = format!("{im}j");
if re != 0.0 {
let mut sign = "";
if im >= 0.0 {
sign = "+";
}
s = format!("{re}{sign}{s}");
}
s
}
28 changes: 28 additions & 0 deletions tests/serializers/test_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from enum import Enum

from pydantic_core import SchemaSerializer, core_schema


# serializing enum calls methods in serializers::infer
def test_infer_to_python():
class MyEnum(Enum):
complex_ = complex(1, 2)

v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
assert v.to_python(MyEnum.complex_, mode='json') == '1+2j'


def test_infer_serialize():
class MyEnum(Enum):
complex_ = complex(1, 2)

v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
assert v.to_json(MyEnum.complex_) == b'"1+2j"'


def test_infer_json_key():
class MyEnum(Enum):
complex_ = {complex(1, 2): 1}

v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
assert v.to_json(MyEnum.complex_) == b'{"1+2j":1}'

0 comments on commit 2ee8fa8

Please sign in to comment.