Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(union_serializer): do not raise warnings in nested unions #1513

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::build_tools::py_schema_err;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::definitions::DefinitionsBuilder;
use crate::tools::{truncate_safe_repr, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
Expand Down Expand Up @@ -89,7 +90,8 @@ fn to_python(
}
}

if retry_with_lax_check {
// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
Expand All @@ -98,8 +100,17 @@ fn to_python(
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check if extra.check != SerCheck::None here. That would tell us if we're inside a nested union serialization. If so, we should return a PydanticSerializationUnexpectedValue error here containing the errors.

That way, the warning and the inference fallback below this point will only happen at the top level union.

for err in &errors {
extra.warnings.custom_warning(err.to_string());
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
}
// Otherwise, if we've encountered errors, return them to the parent union, which should take
// care of the formatting for us
else if !errors.is_empty() {
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}

infer_to_python(value, include, exclude, extra)
Expand All @@ -122,7 +133,8 @@ fn json_key<'a>(
}
}

if retry_with_lax_check {
// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
Expand All @@ -131,10 +143,18 @@ fn json_key<'a>(
}
}

for err in &errors {
extra.warnings.custom_warning(err.to_string());
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
}
// Otherwise, if we've encountered errors, return them to the parent union, which should take
// care of the formatting for us
else if !errors.is_empty() {
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}

infer_json_key(key, extra)
}

Expand All @@ -160,7 +180,8 @@ fn serde_serialize<S: serde::ser::Serializer>(
}
}

if retry_with_lax_check {
// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
Expand All @@ -169,8 +190,14 @@ fn serde_serialize<S: serde::ser::Serializer>(
}
}

for err in &errors {
extra.warnings.custom_warning(err.to_string());
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
} else {
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
// will have to be returned here
Comment on lines +198 to +200
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left this empty for now, as we can't return a PydanticSerializationUnexpectedValue here since the function is supposed to return S::Error, and refactoring this would introduce extra complexity.

}

infer_serialize(value, serializer, include, exclude, extra)
Expand Down
222 changes: 222 additions & 0 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
import json
import uuid
Expand Down Expand Up @@ -778,3 +780,223 @@ class ModelB:
model_b = ModelB(field=1)
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}


class ModelDog:
    """Minimal test object carrying a single ``type_`` attribute tagged ``'dog'``.

    Args:
        type_: The tag to store on the instance; the annotation restricts it to ``'dog'``.
    """

    def __init__(self, type_: Literal['dog']) -> None:
        # Keep the caller's value rather than silently discarding the argument in
        # favour of a hard-coded literal; for every type-correct call the result is equal.
        self.type_ = type_


class ModelCat:
    """Minimal test object carrying a single ``type_`` attribute tagged ``'cat'``.

    Args:
        type_: The tag to store on the instance; the annotation restricts it to ``'cat'``.
    """

    def __init__(self, type_: Literal['cat']) -> None:
        # Keep the caller's value rather than silently discarding the argument in
        # favour of a hard-coded literal; for every type-correct call the result is equal.
        self.type_ = type_


class ModelAlien:
    """Minimal test object carrying a single ``type_`` attribute tagged ``'alien'``.

    Args:
        type_: The tag to store on the instance; the annotation restricts it to ``'alien'``.
    """

    def __init__(self, type_: Literal['alien']) -> None:
        # Keep the caller's value rather than silently discarding the argument in
        # favour of a hard-coded literal; for every type-correct call the result is equal.
        self.type_ = type_


@pytest.fixture
def model_a_b_union_schema() -> core_schema.UnionSchema:
    """Return a union schema whose choices are the ``ModelA`` and ``ModelB`` model schemas.

    ``ModelA`` is described with string fields ``a`` and ``b``; ``ModelB`` with string
    fields ``c`` and ``d``.
    """
    model_a_fields = core_schema.model_fields_schema(
        fields={
            'a': core_schema.model_field(core_schema.str_schema()),
            'b': core_schema.model_field(core_schema.str_schema()),
        },
    )
    model_b_fields = core_schema.model_fields_schema(
        fields={
            'c': core_schema.model_field(core_schema.str_schema()),
            'd': core_schema.model_field(core_schema.str_schema()),
        },
    )
    return core_schema.union_schema(
        [
            core_schema.model_schema(cls=ModelA, schema=model_a_fields),
            core_schema.model_schema(cls=ModelB, schema=model_b_fields),
        ]
    )


@pytest.fixture
def union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:
    """Return a union schema whose two choices are themselves union schemas.

    The first choice is the ``model_a_b_union_schema`` fixture; the second is a union of
    the ``ModelCat`` and ``ModelDog`` model schemas, each with a literal ``type_`` field.
    """
    cat_schema = core_schema.model_schema(
        cls=ModelCat,
        schema=core_schema.model_fields_schema(
            fields={
                'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
            },
        ),
    )
    dog_schema = core_schema.model_schema(
        cls=ModelDog,
        schema=core_schema.model_fields_schema(
            fields={
                'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
            },
        ),
    )
    cat_dog_union = core_schema.union_schema([cat_schema, dog_schema])
    return core_schema.union_schema([model_a_b_union_schema, cat_dog_union])


@pytest.mark.parametrize(
    'input,expected',
    [
        (ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),
        (ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),
        (ModelCat(type_='cat'), {'type_': 'cat'}),
        (ModelDog(type_='dog'), {'type_': 'dog'}),
    ],
)
def test_union_of_unions_of_models(union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any) -> None:
    """Check ``to_python(..., warnings='error')`` on the union-of-unions schema for each model instance."""
    serializer = SchemaSerializer(union_of_unions_schema)
    serialized = serializer.to_python(input, warnings='error')
    assert serialized == expected


def test_union_of_unions_of_models_invalid_variant(union_of_unions_schema: core_schema.UnionSchema) -> None:
    """Serialize a ``ModelAlien`` with the union-of-unions schema and inspect the recorded warning.

    The first recorded warning message must mention every model of both nested unions.
    """
    serializer = SchemaSerializer(union_of_unions_schema)
    # All warnings should be available
    expected_fragments = (
        'Expected `ModelA` but got `ModelAlien`',
        'Expected `ModelB` but got `ModelAlien`',
        'Expected `ModelCat` but got `ModelAlien`',
        'Expected `ModelDog` but got `ModelAlien`',
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        serializer.to_python(ModelAlien(type_='alien'))

    # Only the first recorded warning is inspected; render it once for all checks.
    first_warning_text = str(caught[0].message)
    for fragment in expected_fragments:
        assert fragment in first_warning_text


@pytest.fixture
def tagged_union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:
    """Return a union schema combining a plain union with a tagged union.

    The first choice is the ``model_a_b_union_schema`` fixture; the second is a tagged
    union discriminated on ``type_`` with ``'cat'`` -> ``ModelCat`` and ``'dog'`` -> ``ModelDog``.
    """
    cat_schema = core_schema.model_schema(
        cls=ModelCat,
        schema=core_schema.model_fields_schema(
            fields={
                'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
            },
        ),
    )
    dog_schema = core_schema.model_schema(
        cls=ModelDog,
        schema=core_schema.model_fields_schema(
            fields={
                'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
            },
        ),
    )
    cat_dog_tagged_union = core_schema.tagged_union_schema(
        discriminator='type_',
        choices={'cat': cat_schema, 'dog': dog_schema},
    )
    return core_schema.union_schema([model_a_b_union_schema, cat_dog_tagged_union])


@pytest.mark.parametrize(
    'input,expected',
    [
        (ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),
        (ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),
        (ModelCat(type_='cat'), {'type_': 'cat'}),
        (ModelDog(type_='dog'), {'type_': 'dog'}),
    ],
)
def test_union_of_unions_of_models_with_tagged_union(
    tagged_union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any
) -> None:
    """Check ``to_python(..., warnings='error')`` on the union containing a tagged union for each model instance."""
    serializer = SchemaSerializer(tagged_union_of_unions_schema)
    serialized = serializer.to_python(input, warnings='error')
    assert serialized == expected


def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
    tagged_union_of_unions_schema: core_schema.UnionSchema,
) -> None:
    """Serialize a ``ModelAlien`` with the union-plus-tagged-union schema and inspect the recorded warning.

    The first recorded warning message must mention every model of both nested unions.
    """
    serializer = SchemaSerializer(tagged_union_of_unions_schema)
    # All warnings should be available
    expected_fragments = (
        'Expected `ModelA` but got `ModelAlien`',
        'Expected `ModelB` but got `ModelAlien`',
        'Expected `ModelCat` but got `ModelAlien`',
        'Expected `ModelDog` but got `ModelAlien`',
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        serializer.to_python(ModelAlien(type_='alien'))

    # Only the first recorded warning is inspected; render it once for all checks.
    first_warning_text = str(caught[0].message)
    for fragment in expected_fragments:
        assert fragment in first_warning_text


@pytest.mark.parametrize(
    'input,expected',
    [
        ({True: '1'}, b'{"true":"1"}'),
        ({1: '1'}, b'{"1":"1"}'),
        ({2.3: '1'}, b'{"2.3":"1"}'),
        ({'a': 'b'}, b'{"a":"b"}'),
    ],
)
def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(
    input: dict[bool | int | float | str, str], expected: bytes
) -> None:
    """Check ``to_json(..., warnings='error')`` for dict keys typed as a union of two unions.

    The key schema is a union of ``(bool | int)`` and ``(float | str)``; values are strings.
    """
    bool_or_int = core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()])
    float_or_str = core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()])
    dict_schema = core_schema.dict_schema(
        keys_schema=core_schema.union_schema([bool_or_int, float_or_str]),
        values_schema=core_schema.str_schema(),
    )
    serializer = SchemaSerializer(dict_schema)

    serialized = serializer.to_json(input, warnings='error')
    assert serialized == expected


@pytest.mark.parametrize(
    'input,expected',
    [
        ({'key': True}, b'{"key":true}'),
        ({'key': 1}, b'{"key":1}'),
        ({'key': 2.3}, b'{"key":2.3}'),
        ({'key': 'a'}, b'{"key":"a"}'),
    ],
)
def test_union_of_unions_of_models_with_tagged_union_json_serialization(
    input: dict[str, bool | int | float | str], expected: bytes
) -> None:
    """Check ``to_json(..., warnings='error')`` for dict values typed as a union of two unions.

    Keys are strings; the value schema is a union of ``(bool | int)`` and ``(float | str)``.
    """
    bool_or_int = core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()])
    float_or_str = core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()])
    dict_schema = core_schema.dict_schema(
        keys_schema=core_schema.str_schema(),
        values_schema=core_schema.union_schema([bool_or_int, float_or_str]),
    )
    serializer = SchemaSerializer(dict_schema)

    serialized = serializer.to_json(input, warnings='error')
    assert serialized == expected