diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 24544a971..dcd592a90 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -65,10 +65,61 @@ impl UnionSerializer { } } } + + fn _to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> ToPythonExtractorResult { + to_python_extractor(value, include, exclude, extra, &self.choices) + } } impl_py_gc_traverse!(UnionSerializer { choices }); +#[derive(Debug)] +enum ToPythonExtractorResult { + Success(PyObject), + Errors(SmallVec<[PyErr; SMALL_UNION_THRESHOLD]>), +} + +fn to_python_extractor( + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + choices: &[CombinedSerializer], +) -> ToPythonExtractorResult { + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); + + for comb_serializer in choices { + match comb_serializer { + CombinedSerializer::Union(union_serializer) => { + match union_serializer._to_python(value, include, exclude, extra) { + ToPythonExtractorResult::Errors(errs) => errors.extend(errs), + ToPythonExtractorResult::Success(success) => return ToPythonExtractorResult::Success(success), + } + } + CombinedSerializer::TaggedUnion(tagged_union_serializer) => { + match tagged_union_serializer._to_python(value, include, exclude, extra) { + ToPythonExtractorResult::Errors(errs) => errors.extend(errs), + ToPythonExtractorResult::Success(success) => return ToPythonExtractorResult::Success(success), + } + } + _ => { + match comb_serializer.to_python(value, include, exclude, extra) { + Ok(v) => return ToPythonExtractorResult::Success(v), + Err(err) => errors.push(err), + }; + } + } + } + + ToPythonExtractorResult::Errors(errors) +} + fn to_python( value: &Bound<'_, PyAny>, include: Option<&Bound<'_, PyAny>>, @@ -80,14 +131,13 @@ fn to_python( // try the serializers in left to right order with error_on fallback=true let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - for comb_serializer in choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => errors.push(err), - } - } + let res = to_python_extractor(value, include, exclude, &new_extra, choices); + + let errors = match res { + ToPythonExtractorResult::Success(obj) => return Ok(obj), + ToPythonExtractorResult::Errors(errs) => errs, + }; if retry_with_lax_check { new_extra.check = SerCheck::Lax; @@ -392,6 +442,38 @@ impl TypeSerializer for TaggedUnionSerializer { } impl TaggedUnionSerializer { + fn _to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> ToPythonExtractorResult { + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + + if let Some(tag) = self.get_discriminator_value(value, extra) { + let tag_str = tag.to_string(); + if let Some(&serializer_index) = self.lookup.get(&tag_str) { + let serializer = &self.choices[serializer_index]; + + match serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return ToPythonExtractorResult::Success(v), + Err(_) => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) { + return ToPythonExtractorResult::Success(v); + } + } + } + } + } + } + + to_python_extractor(value, include, exclude, extra, &self.choices) + } + fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option> { let py = value.py(); let discriminator_value = match &self.discriminator { diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 8b6d6f128..eb2e8d4fd 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -778,3 +778,69 @@ class ModelB: model_b = ModelB(field=1) assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'} assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'} + + +class ModelDog: + def __init__(self, type_: Literal['dog']) -> None: + self.type_ = 'dog' + + +class ModelCat: + def __init__(self, type_: Literal['cat']) -> None: + self.type_ = 'cat' + + +def test_union_of_unions_of_models() -> None: + s = SchemaSerializer( + core_schema.union_schema( + [ + core_schema.union_schema( + [ + core_schema.model_schema( + cls=ModelA, + schema=core_schema.model_fields_schema( + fields={ + 'a': core_schema.model_field(core_schema.str_schema()), + 'b': core_schema.model_field(core_schema.str_schema()), + }, + ), + ), + core_schema.model_schema( + cls=ModelB, + schema=core_schema.model_fields_schema( + fields={ + 'c': core_schema.model_field(core_schema.str_schema()), + 'd': core_schema.model_field(core_schema.str_schema()), + }, + ), + ), + ] + ), + core_schema.union_schema( + [ + core_schema.model_schema( + cls=ModelCat, + schema=core_schema.model_fields_schema( + fields={ + 'type_': core_schema.model_field(core_schema.literal_schema(['cat'])), + }, + ), + ), + core_schema.model_schema( + cls=ModelDog, + schema=core_schema.model_fields_schema( + fields={ + 'type_': core_schema.model_field(core_schema.literal_schema(['dog'])), + }, + ), + ), + ] + ), + ] + ) + ) + + assert s.to_python(ModelA(a='a', b='b'), warnings='error') == {'a': 'a', 'b': 'b'} + assert s.to_python(ModelB(c='c', d='d'), warnings='error') == {'c': 'c', 'd': 'd'} + assert s.to_python(ModelCat(type_='cat'), warnings='error') == {'type_': 'cat'} + assert s.to_python(ModelDog(type_='dog'), warnings='error') == {'type_': 'dog'}