Skip to content

Commit

Permalink
fix(union_serializer): do not raise warnings in nested unions
Browse files Browse the repository at this point in the history
In case unions of unions are used, this will bubble-up the errors rather
than warning immediately. If no solution is found among all serializers
by the top-level union, it will warn as before.

Signed-off-by: Luka Peschke <mail@lukapeschke.com>
  • Loading branch information
lukapeschke committed Oct 31, 2024
1 parent 9217019 commit 41d6b25
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 7 deletions.
96 changes: 89 additions & 7 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,61 @@ impl UnionSerializer {
}
}
}

fn _to_python(
&self,
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> ToPythonExtractorResult {
to_python_extractor(value, include, exclude, extra, &self.choices)
}
}

impl_py_gc_traverse!(UnionSerializer { choices });

#[derive(Debug)]
enum ToPythonExtractorResult {
Success(PyObject),
Errors(SmallVec<[PyErr; SMALL_UNION_THRESHOLD]>),
}

fn to_python_extractor(
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
choices: &[CombinedSerializer],
) -> ToPythonExtractorResult {
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer {
CombinedSerializer::Union(union_serializer) => {
match union_serializer._to_python(value, include, exclude, extra) {
ToPythonExtractorResult::Errors(errs) => errors.extend(errs),
ToPythonExtractorResult::Success(success) => return ToPythonExtractorResult::Success(success),
}
}
CombinedSerializer::TaggedUnion(tagged_union_serializer) => {
match tagged_union_serializer._to_python(value, include, exclude, extra) {
ToPythonExtractorResult::Errors(errs) => errors.extend(errs),
ToPythonExtractorResult::Success(success) => return ToPythonExtractorResult::Success(success),
}
}
_ => {
match comb_serializer.to_python(value, include, exclude, extra) {
Ok(v) => return ToPythonExtractorResult::Success(v),
Err(err) => errors.push(err),
};
}
}
}

ToPythonExtractorResult::Errors(errors)
}

fn to_python(
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
Expand All @@ -80,14 +131,13 @@ fn to_python(
// try the serializers in left to right order with error_on fallback=true
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
Err(err) => errors.push(err),
}
}
let res = to_python_extractor(value, include, exclude, &new_extra, choices);

let errors = match res {
ToPythonExtractorResult::Success(obj) => return Ok(obj),
ToPythonExtractorResult::Errors(errs) => errs,
};

if retry_with_lax_check {
new_extra.check = SerCheck::Lax;
Expand Down Expand Up @@ -392,6 +442,38 @@ impl TypeSerializer for TaggedUnionSerializer {
}

impl TaggedUnionSerializer {
fn _to_python(
&self,
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> ToPythonExtractorResult {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(value, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let serializer = &self.choices[serializer_index];

match serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return ToPythonExtractorResult::Success(v),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) {
return ToPythonExtractorResult::Success(v);
}
}
}
}
}
}

to_python_extractor(value, include, exclude, extra, &self.choices)
}

fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
let py = value.py();
let discriminator_value = match &self.discriminator {
Expand Down
66 changes: 66 additions & 0 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,69 @@ class ModelB:
model_b = ModelB(field=1)
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}


class ModelDog:
def __init__(self, type_: Literal['dog']) -> None:
self.type_ = 'dog'


class ModelCat:
def __init__(self, type_: Literal['cat']) -> None:
self.type_ = 'cat'


def test_union_of_unions_of_models() -> None:
s = SchemaSerializer(
core_schema.union_schema(
[
core_schema.union_schema(
[
core_schema.model_schema(
cls=ModelA,
schema=core_schema.model_fields_schema(
fields={
'a': core_schema.model_field(core_schema.str_schema()),
'b': core_schema.model_field(core_schema.str_schema()),
},
),
),
core_schema.model_schema(
cls=ModelB,
schema=core_schema.model_fields_schema(
fields={
'c': core_schema.model_field(core_schema.str_schema()),
'd': core_schema.model_field(core_schema.str_schema()),
},
),
),
]
),
core_schema.union_schema(
[
core_schema.model_schema(
cls=ModelCat,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
},
),
),
core_schema.model_schema(
cls=ModelDog,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
},
),
),
]
),
]
)
)

assert s.to_python(ModelA(a='a', b='b'), warnings='error') == {'a': 'a', 'b': 'b'}
assert s.to_python(ModelB(c='c', d='d'), warnings='error') == {'c': 'c', 'd': 'd'}
assert s.to_python(ModelCat(type_='cat'), warnings='error') == {'type_': 'cat'}
assert s.to_python(ModelDog(type_='dog'), warnings='error') == {'type_': 'dog'}

0 comments on commit 41d6b25

Please sign in to comment.