Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ impl<'a, 'py> SerializationState<'a, 'py> {
self.include_exclude.1.as_ref()
}

pub(crate) fn model_type_name(&self) -> Option<Bound<'py, PyString>> {
self.model.as_ref().and_then(|model| model.get_type().name().ok())
}

pub fn serialize_infer<'slf>(
&'slf mut self,
value: &'slf Bound<'py, PyAny>,
Expand Down
51 changes: 25 additions & 26 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ impl GeneralFieldsSerializer {
pub(crate) fn main_to_python<'py>(
&self,
py: Python<'py>,
model: &Bound<'py, PyAny>,
main_iter: impl Iterator<Item = PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>>,
state: &mut SerializationState<'_, 'py>,
) -> PyResult<Bound<'py, PyDict>> {
Expand Down Expand Up @@ -218,7 +219,7 @@ impl GeneralFieldsSerializer {
return Err(PydanticSerializationUnexpectedValue::new(
Some(format!("Unexpected field `{key}`")),
Some(key_str.to_string()),
state.model_type_name().map(|bound| bound.to_string()),
model_type_name(model),
None,
)
.to_py_err());
Expand All @@ -244,8 +245,8 @@ impl GeneralFieldsSerializer {
Err(PydanticSerializationUnexpectedValue::new(
Some(format!("Expected {required_fields} fields but got {used_req_fields}").to_string()),
state.field_name.as_ref().map(ToString::to_string),
state.model_type_name().map(|bound| bound.to_string()),
state.model.clone().map(Bound::unbind),
model_type_name(model),
Some(model.clone().unbind()),
)
.to_py_err())
} else {
Expand Down Expand Up @@ -353,7 +354,6 @@ impl GeneralFieldsSerializer {
state: &mut SerializationState<'_, 'py>,
) -> PyResult<()> {
if let Some(ref computed_fields) = self.computed_fields {
let state = &mut state.scoped_set(|s| &mut s.model, Some(model.clone()));
computed_fields.to_python(model, output_dict, &self.filter, state)?;
}
Ok(())
Expand All @@ -366,7 +366,6 @@ impl GeneralFieldsSerializer {
state: &mut SerializationState<'_, 'py>,
) -> Result<(), S::Error> {
if let Some(ref computed_fields) = self.computed_fields {
// FIXME: need to match state.model setting above in `add_computed_fields_python`??
computed_fields.serde_serialize::<S>(model, map, &self.filter, state)?;
}
Ok(())
Expand All @@ -390,21 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer {
) -> PyResult<Py<PyAny>> {
let py = value.py();
let missing_sentinel = get_missing_sentinel_object(py);
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let model = state.model.clone().unwrap_or_else(|| value.clone());

let model = get_model(state)?;

let Some((main_dict, extra_dict)) = self.extract_dicts(value) else {
state.warn_fallback_py(self.get_name(), value)?;
return infer_to_python(value, state);
};
let output_dict = self.main_to_python(
py,
dict_items(&main_dict),
// FIXME: should also set model for extra serialization?
&mut state.scoped_set(|s| &mut s.model, Some(model.clone())),
)?;
let output_dict = self.main_to_python(py, &model, dict_items(&main_dict), state)?;

// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
Expand Down Expand Up @@ -448,24 +440,15 @@ impl TypeSerializer for GeneralFieldsSerializer {
return infer_serialize(value, serializer, state);
};
let missing_sentinel = get_missing_sentinel_object(value.py());
// If there is already a model registered (from a dataclass, BaseModel)
// then do not touch it
// If there is no model, we (a TypedDict) are the model
let model = state.model.clone().unwrap_or_else(|| value.clone());
let model = get_model(state).map_err(py_err_se_err)?;

let expected_len = match self.mode {
FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(),
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
};
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't both with `used_req_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = self.main_serde_serialize(
dict_items(&main_dict),
expected_len,
serializer,
// FIXME: should also set model for extra serialization?
&mut state.scoped_set(|s| &mut s.model, Some(model.clone())),
)?;
let mut map = self.main_serde_serialize(dict_items(&main_dict), expected_len, serializer, state)?;

// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
Expand Down Expand Up @@ -507,3 +490,19 @@ fn dict_items<'py>(
let main_items: SmallVec<[_; 16]> = main_dict.iter().collect();
main_items.into_iter().map(Ok)
}

fn get_model<'py>(state: &mut SerializationState<'_, 'py>) -> PyResult<Bound<'py, PyAny>> {
state.model.clone().ok_or_else(|| {
PydanticSerializationUnexpectedValue::new(
Some("No model found for fields serialization".to_string()),
None,
None,
None,
)
.to_py_err()
})
}

fn model_type_name(model: &Bound<'_, PyAny>) -> Option<String> {
model.get_type().name().ok().map(|s| s.to_string())
}
3 changes: 2 additions & 1 deletion src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ combined_serializer! {
super::type_serializers::function::FunctionPlainSerializerBuilder;
super::type_serializers::function::FunctionWrapSerializerBuilder;
super::type_serializers::model::ModelFieldsBuilder;
super::type_serializers::typed_dict::TypedDictBuilder;
}
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
// `find_serializer` so they can be used via a `type` str.
Expand Down Expand Up @@ -151,6 +150,7 @@ combined_serializer! {
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
Tuple: super::type_serializers::tuple::TupleSerializer;
Complex: super::type_serializers::complex::ComplexSerializer;
TypedDict: super::type_serializers::typed_dict::TypedDictSerializer;
}
}

Expand Down Expand Up @@ -356,6 +356,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TypedDict(inner) => inner.py_gc_traverse(visit),
}
}
}
Expand Down
11 changes: 5 additions & 6 deletions src/serializers/type_serializers/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,21 @@ impl TypeSerializer for DataclassSerializer {
value: &Bound<'py, PyAny>,
state: &mut SerializationState<'_, 'py>,
) -> PyResult<Py<PyAny>> {
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
if self.allow_value(value, state)? {
let model = value;
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
let py = value.py();
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
let output_dict: Bound<PyDict> =
fields_serializer.main_to_python(py, known_dataclass_iter(&self.fields, value), state)?;
fields_serializer.main_to_python(py, model, known_dataclass_iter(&self.fields, model), state)?;

fields_serializer.add_computed_fields_python(value, &output_dict, state)?;
fields_serializer.add_computed_fields_python(model, &output_dict, state)?;
Ok(output_dict.into())
} else {
let inner_value = self.get_inner_value(value)?;
self.serializer.to_python(&inner_value, state)
}
} else {
// FIXME: probably don't want to have state.model set here, should move the scoped_set above?
state.warn_fallback_py(self.get_name(), value)?;
infer_to_python(value, state)
}
Expand All @@ -189,8 +189,8 @@ impl TypeSerializer for DataclassSerializer {
serializer: S,
state: &mut SerializationState<'_, 'py>,
) -> Result<S::Ok, S::Error> {
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
if self.allow_value(value, state).map_err(py_err_se_err)? {
let state = &mut state.scoped_set(|s| &mut s.model, Some(value.clone()));
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
let expected_len = self.fields.len() + fields_serializer.computed_field_count();
let mut map = fields_serializer.main_serde_serialize(
Expand All @@ -206,7 +206,6 @@ impl TypeSerializer for DataclassSerializer {
self.serializer.serde_serialize(&inner_value, serializer, state)
}
} else {
// FIXME: probably don't want to have state.model set here, should move the scoped_set above?
state.warn_fallback_ser::<S>(self.get_name(), value)?;
infer_serialize(value, serializer, state)
}
Expand Down
54 changes: 51 additions & 3 deletions src/serializers/type_serializers/typed_dict.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::sync::Arc;

use pyo3::intern;
Expand All @@ -9,14 +10,20 @@ use ahash::AHashMap;
use crate::build_tools::py_schema_err;
use crate::build_tools::{py_schema_error_type, schema_or_config, ExtraBehavior};
use crate::definitions::DefinitionsBuilder;
use crate::serializers::shared::TypeSerializer;
use crate::serializers::SerializationState;
use crate::tools::SchemaDict;

use super::{BuildSerializer, CombinedSerializer, ComputedFields, FieldsMode, GeneralFieldsSerializer, SerField};

#[derive(Debug)]
pub struct TypedDictBuilder;
pub struct TypedDictSerializer {
serializer: GeneralFieldsSerializer,
}

impl_py_gc_traverse!(TypedDictSerializer { serializer });

impl BuildSerializer for TypedDictBuilder {
impl BuildSerializer for TypedDictSerializer {
const EXPECTED_TYPE: &'static str = "typed-dict";

fn build(
Expand Down Expand Up @@ -82,10 +89,51 @@ impl BuildSerializer for TypedDictBuilder {
}
}

// FIXME: computed fields do not work for TypedDict, and may never
// see the closed https://github.com/pydantic/pydantic-core/pull/1018
let computed_fields = ComputedFields::new(schema, config, definitions)?;

Ok(Arc::new(
GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields).into(),
Self {
serializer: GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields),
}
.into(),
))
}
}

impl TypeSerializer for TypedDictSerializer {
fn to_python<'py>(
&self,
value: &Bound<'py, PyAny>,
state: &mut SerializationState<'_, 'py>,
) -> PyResult<Py<PyAny>> {
self.serializer
.to_python(value, &mut state.scoped_set(|s| &mut s.model, Some(value.clone())))
}

fn json_key<'a, 'py>(
&self,
key: &'a Bound<'py, PyAny>,
state: &mut SerializationState<'_, 'py>,
) -> PyResult<Cow<'a, str>> {
self.invalid_as_json_key(key, state, "typed-dict")
}

fn serde_serialize<'py, S: serde::ser::Serializer>(
&self,
value: &Bound<'py, PyAny>,
serializer: S,
state: &mut SerializationState<'_, 'py>,
) -> Result<S::Ok, S::Error> {
self.serializer.serde_serialize(
value,
serializer,
&mut state.scoped_set(|s| &mut s.model, Some(value.clone())),
)
}

fn get_name(&self) -> &'static str {
"typed-dict"
}
}
32 changes: 32 additions & 0 deletions tests/serializers/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,35 @@ class Model(TypedDict):
)
s = SchemaSerializer(schema, config=core_schema.CoreConfig(serialize_by_alias=config or False))
assert s.to_python(Model(my_field=1), by_alias=runtime) == expected


def test_nested_typed_dict_field_serializers():
class Model(TypedDict):
x: Any

class OuterModel(TypedDict):
model: Model

schema = core_schema.typed_dict_schema(
{
'x': core_schema.typed_dict_field(
core_schema.any_schema(
serialization=core_schema.wrap_serializer_function_ser_schema(
# in an incorrect core implementation, self could be OuterModel here
lambda self, v, serializer: f'{list(self.keys())}',
is_field_serializer=True,
schema=core_schema.any_schema(),
)
)
)
}
)
outer_schema = core_schema.typed_dict_schema({'model': core_schema.typed_dict_field(schema)})

s = SchemaSerializer(schema)
assert s.to_python(Model(x=None)) == {'x': "['x']"}

outer_s = SchemaSerializer(outer_schema)
# if the inner field serializer incorrectly receives OuterModel as self, the keys
# will be ['model'] instead of ['x']
assert outer_s.to_python(OuterModel(model=Model(x=None))) == {'model': {'x': "['x']"}}
Loading