Skip to content

Commit

Permalink
fix wrap serializer breaking union serialization in presence of ext…
Browse files Browse the repository at this point in the history
…ra fields (#1530)
  • Loading branch information
davidhewitt authored Nov 12, 2024
1 parent cd0346d commit a3f13c7
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 57 deletions.
4 changes: 4 additions & 0 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ impl<'a> Extra<'a> {
pub fn serialize_infer<'py>(&'py self, value: &'py Bound<'py, PyAny>) -> super::infer::SerializeInfer<'py> {
super::infer::SerializeInfer::new(value, None, None, self)
}

pub(crate) fn model_type_name(&self) -> Option<Bound<'a, PyString>> {
self.model.and_then(|model| model.get_type().name().ok())
}
}

#[derive(Clone, Copy, PartialEq, Eq)]
Expand Down
23 changes: 12 additions & 11 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,15 @@ impl GeneralFieldsSerializer {
};
output_dict.set_item(key, value)?;
} else if field_extra.check == SerCheck::Strict {
return Err(PydanticSerializationUnexpectedValue::new_err(None));
let type_name = field_extra.model_type_name();
return Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
"Unexpected field `{key}`{for_type_name}",
for_type_name = if let Some(type_name) = type_name {
format!(" for type `{type_name}`")
} else {
String::new()
},
))));
}
}
}
Expand All @@ -212,22 +220,15 @@ impl GeneralFieldsSerializer {
&& self.required_fields > used_req_fields
{
let required_fields = self.required_fields;
let type_name = match extra.model {
Some(model) => model
.get_type()
.qualname()
.ok()
.unwrap_or_else(|| PyString::new_bound(py, "<unknown python object>"))
.to_string(),
None => "<unknown python object>".to_string(),
};
let type_name = extra.model_type_name();
let field_value = match extra.model {
Some(model) => truncate_safe_repr(model, Some(100)),
None => "<unknown python object>".to_string(),
};

Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
"Expected {required_fields} fields but got {used_req_fields} for type `{type_name}` with value `{field_value}` - serialized value may not be as expected."
"Expected {required_fields} fields but got {used_req_fields}{for_type_name} with value `{field_value}` - serialized value may not be as expected.",
for_type_name = if let Some(type_name) = type_name { format!(" for type `{type_name}`") } else { String::new() },
))))
} else {
Ok(output_dict)
Expand Down
15 changes: 15 additions & 0 deletions src/serializers/type_serializers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ impl FunctionPlainSerializer {
.expect("fallback_serializer unexpectedly none")
.as_ref()
}

fn retry_with_lax_check(&self) -> bool {
self.fallback_serializer
.as_ref()
.map_or(false, |f| f.retry_with_lax_check())
|| self.return_serializer.retry_with_lax_check()
}
}

fn on_error(py: Python, err: PyErr, function_name: &str, extra: &Extra) -> PyResult<()> {
Expand Down Expand Up @@ -271,6 +278,10 @@ macro_rules! function_type_serializer {
fn get_name(&self) -> &str {
&self.name
}

fn retry_with_lax_check(&self) -> bool {
self.retry_with_lax_check()
}
}
};
}
Expand Down Expand Up @@ -409,6 +420,10 @@ impl FunctionWrapSerializer {
fn get_fallback_serializer(&self) -> &CombinedSerializer {
self.serializer.as_ref()
}

fn retry_with_lax_check(&self) -> bool {
self.serializer.retry_with_lax_check() || self.return_serializer.retry_with_lax_check()
}
}

impl_py_gc_traverse!(FunctionWrapSerializer {
Expand Down
139 changes: 93 additions & 46 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,53 +62,36 @@ def __init__(self, c, d):
@pytest.fixture(scope='module')
def model_serializer() -> SchemaSerializer:
return SchemaSerializer(
{
'type': 'union',
'choices': [
{
'type': 'model',
'cls': ModelA,
'schema': {
'type': 'model-fields',
'fields': {
'a': {'type': 'model-field', 'schema': {'type': 'bytes'}},
'b': {
'type': 'model-field',
'schema': {
'type': 'float',
'serialization': {
'type': 'format',
'formatting_string': '0.1f',
'when_used': 'unless-none',
},
},
},
},
},
},
{
'type': 'model',
'cls': ModelB,
'schema': {
'type': 'model-fields',
'fields': {
'c': {'type': 'model-field', 'schema': {'type': 'bytes'}},
'd': {
'type': 'model-field',
'schema': {
'type': 'float',
'serialization': {
'type': 'format',
'formatting_string': '0.2f',
'when_used': 'unless-none',
},
},
},
},
},
},
core_schema.union_schema(
[
core_schema.model_schema(
ModelA,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(core_schema.bytes_schema()),
'b': core_schema.model_field(
core_schema.float_schema(
serialization=core_schema.format_ser_schema('0.1f', when_used='unless-none')
)
),
}
),
),
core_schema.model_schema(
ModelB,
core_schema.model_fields_schema(
{
'c': core_schema.model_field(core_schema.bytes_schema()),
'd': core_schema.model_field(
core_schema.float_schema(
serialization=core_schema.format_ser_schema('0.2f', when_used='unless-none')
)
),
}
),
),
],
}
)
)


Expand Down Expand Up @@ -782,6 +765,70 @@ class ModelB:
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}


def test_union_model_wrap_serializer():
def wrap_serializer(value, handler):
return handler(value)

class Data:
pass

class ModelA:
a: Data

class ModelB:
a: Data

model_serializer = SchemaSerializer(
core_schema.union_schema(
[
core_schema.model_schema(
ModelA,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(
core_schema.model_schema(
Data,
core_schema.model_fields_schema({}),
)
),
},
),
serialization=core_schema.wrap_serializer_function_ser_schema(wrap_serializer),
),
core_schema.model_schema(
ModelB,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(
core_schema.model_schema(
Data,
core_schema.model_fields_schema({}),
)
),
},
),
serialization=core_schema.wrap_serializer_function_ser_schema(wrap_serializer),
),
],
)
)

input_value = ModelA()
input_value.a = Data()

assert model_serializer.to_python(input_value) == {'a': {}}
assert model_serializer.to_python(input_value, mode='json') == {'a': {}}
assert model_serializer.to_json(input_value) == b'{"a":{}}'

# add some additional attribute, should be ignored & not break serialization

input_value.a._a = 'foo'

assert model_serializer.to_python(input_value) == {'a': {}}
assert model_serializer.to_python(input_value, mode='json') == {'a': {}}
assert model_serializer.to_json(input_value) == b'{"a":{}}'


class ModelDog:
def __init__(self, type_: Literal['dog']) -> None:
self.type_ = 'dog'
Expand Down

0 comments on commit a3f13c7

Please sign in to comment.