Skip to content

Commit

Permalink
Try each option in union serializer before inference (#1398)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Aug 12, 2024
1 parent 863640b commit fd81a75
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 13 deletions.
40 changes: 29 additions & 11 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use smallvec::SmallVec;
use std::borrow::Cow;

use crate::build_tools::py_schema_err;
use crate::definitions::DefinitionsBuilder;
use crate::tools::SchemaDict;
use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY};
use crate::PydanticSerializationUnexpectedValue;

use super::{
infer_json_key, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, CombinedSerializer, Extra,
SerCheck, TypeSerializer,
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
TypeSerializer,
};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -78,13 +79,14 @@ impl TypeSerializer for UnionSerializer {
// try the serializers in left to right order with error_on fallback=true
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new();

for comb_serializer in &self.choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
true => (),
false => return Err(err),
false => errors.push(err),
},
}
}
Expand All @@ -95,25 +97,31 @@ impl TypeSerializer for UnionSerializer {
Ok(v) => return Ok(v),
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
true => (),
false => return Err(err),
false => errors.push(err),
},
}
}
}

for err in &errors {
extra.warnings.custom_warning(err.to_string());
}

extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
infer_to_python(value, include, exclude, extra)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new();

for comb_serializer in &self.choices {
match comb_serializer.json_key(key, &new_extra) {
Ok(v) => return Ok(v),
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(key.py()) {
true => (),
false => return Err(err),
false => errors.push(err),
},
}
}
Expand All @@ -124,12 +132,16 @@ impl TypeSerializer for UnionSerializer {
Ok(v) => return Ok(v),
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(key.py()) {
true => (),
false => return Err(err),
false => errors.push(err),
},
}
}
}

for err in &errors {
extra.warnings.custom_warning(err.to_string());
}

extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
infer_json_key(key, extra)
}
Expand All @@ -145,12 +157,14 @@ impl TypeSerializer for UnionSerializer {
let py = value.py();
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new();

for comb_serializer in &self.choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
true => (),
false => return Err(py_err_se_err(err)),
false => errors.push(err),
},
}
}
Expand All @@ -159,14 +173,18 @@ impl TypeSerializer for UnionSerializer {
for comb_serializer in &self.choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
true => (),
false => return Err(py_err_se_err(err)),
false => errors.push(err),
},
}
}
}

for err in &errors {
extra.warnings.custom_warning(err.to_string());
}

extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
infer_serialize(value, serializer, include, exclude, extra)
}
Expand Down
2 changes: 2 additions & 0 deletions src/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,5 @@ pub(crate) fn new_py_string<'py>(py: Python<'py>, s: &str, cache_str: StringCach
pystring_fast_new(py, s, ascii_only)
}
}

pub(crate) const UNION_ERR_SMALLVEC_CAPACITY: usize = 4;
4 changes: 2 additions & 2 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::errors::{ErrorType, ToErrorValue, ValError, ValLineError, ValResult};
use crate::input::{BorrowInput, Input, ValidatedDict};
use crate::lookup_key::LookupKey;
use crate::py_gc::PyGcTraverse;
use crate::tools::SchemaDict;
use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY};

use super::custom_error::CustomError;
use super::literal::LiteralLookup;
Expand Down Expand Up @@ -249,7 +249,7 @@ struct ChoiceLineErrors<'a> {

enum MaybeErrors<'a> {
Custom(&'a CustomError),
Errors(SmallVec<[ChoiceLineErrors<'a>; 4]>),
Errors(SmallVec<[ChoiceLineErrors<'a>; UNION_ERR_SMALLVEC_CAPACITY]>),
}

impl<'a> MaybeErrors<'a> {
Expand Down
24 changes: 24 additions & 0 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,27 @@ def test_union_serializer_picks_exact_type_over_subclass_json(
)
assert s.to_python(input_value, mode='json') == expected_value
assert s.to_json(input_value) == json.dumps(expected_value).encode()


def test_custom_serializer() -> None:
s = SchemaSerializer(
core_schema.union_schema(
[
core_schema.dict_schema(
keys_schema=core_schema.any_schema(),
values_schema=core_schema.any_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x['id']),
),
core_schema.list_schema(
items_schema=core_schema.dict_schema(
keys_schema=core_schema.any_schema(),
values_schema=core_schema.any_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x['id']),
)
),
]
)
)
print(s)
assert s.to_python([{'id': 1}, {'id': 2}]) == [1, 2]
assert s.to_python({'id': 1}) == 1

0 comments on commit fd81a75

Please sign in to comment.