Skip to content

Commit eb48cad

Browse files
davidhewittViicos
andauthored
respect field_serializer when using serialize_as_any=True (#1835)
Co-authored-by: Victorien <65306057+Viicos@users.noreply.github.com>
1 parent 0425954 commit eb48cad

File tree

4 files changed

+167
-46
lines changed

4 files changed

+167
-46
lines changed

src/serializers/fields.rs

Lines changed: 90 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use smallvec::SmallVec;
1111

1212
use crate::common::missing_sentinel::get_missing_sentinel_object;
1313
use crate::serializers::extra::SerCheck;
14+
use crate::serializers::type_serializers::any::AnySerializer;
15+
use crate::serializers::type_serializers::function::{FunctionPlainSerializer, FunctionWrapSerializer};
1416
use crate::PydanticSerializationUnexpectedValue;
1517

1618
use super::computed_fields::ComputedFields;
@@ -190,31 +192,26 @@ impl GeneralFieldsSerializer {
190192
..extra
191193
};
192194
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
193-
if let Some(field) = op_field {
194-
if let Some(ref serializer) = field.serializer {
195-
if exclude_default(&value, &field_extra, serializer)? {
196-
continue;
197-
}
198-
if serialization_exclude_if(field.serialization_exclude_if.as_ref(), &value)? {
199-
continue;
200-
}
201-
let value =
202-
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?;
203-
let output_key = field.get_key_py(output_dict.py(), &field_extra);
204-
output_dict.set_item(output_key, value)?;
205-
}
195+
let (key, serializer) = if let Some(field) = op_field {
196+
let serializer = Self::prepare_value(&value, field, &field_extra)?;
206197

207198
if field.required {
208199
used_req_fields += 1;
209200
}
210-
} else if self.mode == FieldsMode::TypedDictAllow {
211-
let value = match &self.extra_serializer {
212-
Some(serializer) => {
213-
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
214-
}
215-
_ => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?,
201+
202+
let Some(serializer) = serializer else {
203+
continue;
216204
};
217-
output_dict.set_item(key, value)?;
205+
206+
(field.get_key_py(output_dict.py(), &field_extra), serializer)
207+
} else if self.mode == FieldsMode::TypedDictAllow {
208+
let serializer = self
209+
.extra_serializer
210+
.as_ref()
211+
// If using `serialize_as_any`, extras are always inferred
212+
.filter(|_| !extra.serialize_as_any)
213+
.unwrap_or_else(|| AnySerializer::get());
214+
(&key, serializer)
218215
} else if field_extra.check == SerCheck::Strict {
219216
return Err(PydanticSerializationUnexpectedValue::new(
220217
Some(format!("Unexpected field `{key}`")),
@@ -223,7 +220,18 @@ impl GeneralFieldsSerializer {
223220
None,
224221
)
225222
.to_py_err());
226-
}
223+
} else {
224+
continue;
225+
};
226+
227+
// Use `no_infer` here because the `serialize_as_any` logic has been handled in `prepare_value`
228+
let value = serializer.to_python_no_infer(
229+
&value,
230+
next_include.as_ref(),
231+
next_exclude.as_ref(),
232+
&field_extra,
233+
)?;
234+
output_dict.set_item(key, value)?;
227235
}
228236
}
229237

@@ -257,7 +265,7 @@ impl GeneralFieldsSerializer {
257265
extra: Extra,
258266
) -> Result<S::SerializeMap, S::Error> {
259267
// NOTE! As above, we maintain the order of the input dict assuming that's right
260-
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
268+
// we don't both with `used_req_fields` here because on unions, `to_python(..., mode='json')` is used
261269
let mut map = serializer.serialize_map(Some(expected_len))?;
262270

263271
for result in main_iter {
@@ -278,26 +286,23 @@ impl GeneralFieldsSerializer {
278286
let filter = self.filter.key_filter(&key, include, exclude).map_err(py_err_se_err)?;
279287
if let Some((next_include, next_exclude)) = filter {
280288
if let Some(field) = self.fields.get(key_str) {
281-
if let Some(ref serializer) = field.serializer {
282-
if exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
283-
continue;
284-
}
285-
if serialization_exclude_if(field.serialization_exclude_if.as_ref(), &value)
286-
.map_err(py_err_se_err)?
287-
{
288-
continue;
289-
}
290-
let s = PydanticSerializer::new(
291-
&value,
292-
serializer,
293-
next_include.as_ref(),
294-
next_exclude.as_ref(),
295-
&field_extra,
296-
);
297-
let output_key = field.get_key_json(key_str, &field_extra);
298-
map.serialize_entry(&output_key, &s)?;
299-
}
289+
let Some(serializer) = Self::prepare_value(&value, field, &field_extra).map_err(py_err_se_err)?
290+
else {
291+
continue;
292+
};
293+
294+
// Use `no_infer` here because the `serialize_as_any` logic has been handled in `prepare_value`
295+
let s = PydanticSerializer::new_no_infer(
296+
&value,
297+
serializer,
298+
next_include.as_ref(),
299+
next_exclude.as_ref(),
300+
&field_extra,
301+
);
302+
let output_key = field.get_key_json(key_str, &field_extra);
303+
map.serialize_entry(&output_key, &s)?;
300304
} else if self.mode == FieldsMode::TypedDictAllow {
305+
// FIXME: why is `extra_serializer` not used here when `serialize_as_any` is not set?
301306
let output_key = infer_json_key(&key, &field_extra).map_err(py_err_se_err)?;
302307
let s = SerializeInfer::new(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra);
303308
map.serialize_entry(&output_key, &s)?;
@@ -308,6 +313,49 @@ impl GeneralFieldsSerializer {
308313
Ok(map)
309314
}
310315

316+
/// Gets the serializer to use for a field, applying `serialize_as_any` logic and applying any
317+
/// field-level exclusions
318+
fn prepare_value<'s>(
319+
value: &Bound<'_, PyAny>,
320+
field: &'s SerField,
321+
field_extra: &Extra,
322+
) -> PyResult<Option<&'s Arc<CombinedSerializer>>> {
323+
let Some(serializer) = field.serializer.as_ref() else {
324+
// field excluded at schema level
325+
return Ok(None);
326+
};
327+
328+
if exclude_default(value, field_extra, serializer)? {
329+
return Ok(None);
330+
}
331+
332+
// FIXME: should `exclude_if` be applied to extra fields too?
333+
if serialization_exclude_if(field.serialization_exclude_if.as_ref(), value)? {
334+
return Ok(None);
335+
}
336+
337+
Ok(Some(
338+
if field_extra.serialize_as_any &&
339+
// if serialize_as_any is set, we ensure that field serializers are
340+
// still used, because this would match the `SerializeAsAny` annotation
341+
// on a field
342+
!matches!(
343+
serializer.as_ref(),
344+
CombinedSerializer::Function(FunctionPlainSerializer {
345+
is_field_serializer: true,
346+
..
347+
}) | CombinedSerializer::FunctionWrap(FunctionWrapSerializer {
348+
is_field_serializer: true,
349+
..
350+
})
351+
) {
352+
AnySerializer::get()
353+
} else {
354+
serializer
355+
},
356+
))
357+
}
358+
311359
pub(crate) fn add_computed_fields_python(
312360
&self,
313361
model: Option<&Bound<'_, PyAny>>,
@@ -425,7 +473,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
425473
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
426474
};
427475
// NOTE! As above, we maintain the order of the input dict assuming that's right
428-
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
476+
// we don't both with `used_req_fields` here because on unions, `to_python(..., mode='json')` is used
429477
let mut map = self.main_serde_serialize(
430478
dict_items(&main_dict),
431479
expected_len,

src/serializers/shared.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::build_tools::py_schema_error_type;
1818
use crate::definitions::DefinitionsBuilder;
1919
use crate::py_gc::PyGcTraverse;
2020
use crate::serializers::ser::PythonSerializer;
21+
use crate::serializers::type_serializers::any::AnySerializer;
2122
use crate::tools::{py_err, SchemaDict};
2223

2324
use super::errors::se_err_py_err;
@@ -418,6 +419,27 @@ impl<'py> PydanticSerializer<'py> {
418419
include: Option<&'py Bound<'py, PyAny>>,
419420
exclude: Option<&'py Bound<'py, PyAny>>,
420421
extra: &'py Extra<'py>,
422+
) -> Self {
423+
Self {
424+
value,
425+
serializer: if extra.serialize_as_any {
426+
AnySerializer::get()
427+
} else {
428+
serializer
429+
},
430+
include,
431+
exclude,
432+
extra,
433+
}
434+
}
435+
436+
/// Same as above but will not fall back to type inference when `serialize_as_any` is set
437+
pub(crate) fn new_no_infer(
438+
value: &'py Bound<'py, PyAny>,
439+
serializer: &'py CombinedSerializer,
440+
include: Option<&'py Bound<'py, PyAny>>,
441+
exclude: Option<&'py Bound<'py, PyAny>>,
442+
extra: &'py Extra<'py>,
421443
) -> Self {
422444
Self {
423445
value,
@@ -431,8 +453,9 @@ impl<'py> PydanticSerializer<'py> {
431453

432454
impl Serialize for PydanticSerializer<'_> {
433455
fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
456+
// inference is handled in the constructor
434457
self.serializer
435-
.serde_serialize(self.value, serializer, self.include, self.exclude, self.extra)
458+
.serde_serialize_no_infer(self.value, serializer, self.include, self.exclude, self.extra)
436459
}
437460
}
438461

src/serializers/type_serializers/function.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ pub struct FunctionPlainSerializer {
8181
// fallback serializer - used when when_used decides that this serializer should not be used
8282
fallback_serializer: Option<Arc<CombinedSerializer>>,
8383
when_used: WhenUsed,
84-
is_field_serializer: bool,
84+
pub(crate) is_field_serializer: bool,
8585
info_arg: bool,
8686
}
8787

@@ -334,7 +334,7 @@ pub struct FunctionWrapSerializer {
334334
function_name: String,
335335
return_serializer: Arc<CombinedSerializer>,
336336
when_used: WhenUsed,
337-
is_field_serializer: bool,
337+
pub(crate) is_field_serializer: bool,
338338
info_arg: bool,
339339
}
340340

tests/serializers/test_serialize_as_any.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
2-
from typing import Optional
2+
from typing import Callable, Optional
33

4+
import pytest
45
from typing_extensions import TypedDict
56

67
from pydantic_core import SchemaSerializer, SchemaValidator, core_schema
@@ -370,3 +371,52 @@ def a_model_serializer(self, handler, info):
370371
assert MyModel.__pydantic_serializer__.to_python(instance, serialize_as_any=True) == {
371372
'a_field_wrapped': {'an_inner_field': 1},
372373
}
374+
375+
376+
@pytest.fixture(params=['model', 'dataclass'])
377+
def container_schema_builder(
378+
request: pytest.FixtureRequest,
379+
) -> Callable[[dict[str, core_schema.CoreSchema]], core_schema.CoreSchema]:
380+
if request.param == 'model':
381+
return lambda fields: core_schema.model_schema(
382+
cls=type('Test', (), {}),
383+
schema=core_schema.model_fields_schema(
384+
fields={k: core_schema.model_field(schema=v) for k, v in fields.items()},
385+
),
386+
)
387+
elif request.param == 'dataclass':
388+
return lambda fields: core_schema.dataclass_schema(
389+
cls=dataclass(type('Test', (), {})),
390+
schema=core_schema.dataclass_args_schema(
391+
'Test',
392+
fields=[core_schema.dataclass_field(name=k, schema=v) for k, v in fields.items()],
393+
),
394+
fields=[k for k in fields.keys()],
395+
)
396+
else:
397+
raise ValueError(f'Unknown container type {request.param}')
398+
399+
400+
def test_serialize_as_any_with_field_serializer(container_schema_builder) -> None:
401+
# https://github.com/pydantic/pydantic/issues/12379
402+
403+
schema = container_schema_builder(
404+
{
405+
'value': core_schema.int_schema(
406+
serialization=core_schema.plain_serializer_function_ser_schema(
407+
lambda model, v: v * 2, is_field_serializer=True
408+
)
409+
)
410+
}
411+
)
412+
413+
v = SchemaValidator(schema).validate_python({'value': 123})
414+
cls = type(v)
415+
s = SchemaSerializer(schema)
416+
# necessary to ensure that type inference will pick up the serializer
417+
cls.__pydantic_serializer__ = s
418+
419+
assert s.to_python(v, serialize_as_any=False) == {'value': 246}
420+
assert s.to_python(v, serialize_as_any=True) == {'value': 246}
421+
assert s.to_json(v, serialize_as_any=False) == b'{"value":246}'
422+
assert s.to_json(v, serialize_as_any=True) == b'{"value":246}'

0 commit comments

Comments
 (0)