From cd0346d68bf7c15dcd73b56d5084d094b0dea351 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Mon, 11 Nov 2024 10:37:28 -0500 Subject: [PATCH] More union serialization tidying (#1536) --- Makefile | 2 +- src/serializers/type_serializers/union.rs | 31 ++++++++---- tests/serializers/test_union.py | 62 +++++++++++++++++++++++ 3 files changed, 85 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index ad9de667b..4186369a7 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ install: pip install -U pip wheel pre-commit pip install -r tests/requirements.txt pip install -r tests/requirements-linting.txt - pip install -e . + pip install -v -e . pre-commit install .PHONY: install-rust-coverage diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index ac674ef34..35a5c8bc7 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -422,19 +422,32 @@ impl TaggedUnionSerializer { fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option> { let py = value.py(); let discriminator_value = match &self.discriminator { - Discriminator::LookupKey(lookup_key) => lookup_key - .simple_py_get_attr(value) - .ok() - .and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))), + Discriminator::LookupKey(lookup_key) => { + // we're pretty lax here, we allow either dict[key] or object.key, as we very well could + // be doing a discriminator lookup on a typed dict, and there's no good way to check that + // at this point. we could be more strict and only do this in lax mode... + let getattr_result = match value.is_instance_of::() { + true => { + let value_dict = value.downcast::().unwrap(); + lookup_key.py_get_dict_item(value_dict).ok() + } + false => lookup_key.simple_py_get_attr(value).ok(), + }; + getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))) + } Discriminator::Function(func) => func.call1(py, (value,)).ok(), }; if discriminator_value.is_none() { let value_str = truncate_safe_repr(value, None); - extra.warnings.custom_warning( - format!( - "Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization." - ) - ); + + // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning + if extra.check == SerCheck::None { + extra.warnings.custom_warning( + format!( + "Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization." + ) + ); + } } discriminator_value } diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 66ec3b9b8..b77f8cb8f 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -948,6 +948,43 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant( assert m in str(w[0].message) +def test_mixed_union_models_and_other_types() -> None: + s = SchemaSerializer( + core_schema.union_schema( + [ + core_schema.tagged_union_schema( + discriminator='type_', + choices={ + 'cat': core_schema.model_schema( + cls=ModelCat, + schema=core_schema.model_fields_schema( + fields={ + 'type_': core_schema.model_field(core_schema.literal_schema(['cat'])), + }, + ), + ), + 'dog': core_schema.model_schema( + cls=ModelDog, + schema=core_schema.model_fields_schema( + fields={ + 'type_': core_schema.model_field(core_schema.literal_schema(['dog'])), + }, + ), + ), + }, + ), + core_schema.str_schema(), + ] + ) + ) + + assert s.to_python(ModelCat(type_='cat'), warnings='error') == {'type_': 'cat'} + assert s.to_python(ModelDog(type_='dog'), warnings='error') == {'type_': 'dog'} + # note, this fails as ModelCat and ModelDog (discriminator warnings, etc), but the warnings + # don't bubble up to this level :) + assert s.to_python('a string', warnings='error') == 'a string' + + @pytest.mark.parametrize( 'input,expected', [ @@ -1000,3 +1037,28 @@ def test_union_of_unions_of_models_with_tagged_union_json_serialization( ) assert s.to_json(input, warnings='error') == expected + + +def test_discriminated_union_ser_with_typed_dict() -> None: + v = SchemaSerializer( + core_schema.tagged_union_schema( + { + 'a': core_schema.typed_dict_schema( + { + 'type': core_schema.typed_dict_field(core_schema.literal_schema(['a'])), + 'a': core_schema.typed_dict_field(core_schema.int_schema()), + } + ), + 'b': core_schema.typed_dict_schema( + { + 'type': core_schema.typed_dict_field(core_schema.literal_schema(['b'])), + 'b': core_schema.typed_dict_field(core_schema.str_schema()), + } + ), + }, + discriminator='type', + ) + ) + + assert v.to_python({'type': 'a', 'a': 1}, warnings='error') == {'type': 'a', 'a': 1} + assert v.to_python({'type': 'b', 'b': 'foo'}, warnings='error') == {'type': 'b', 'b': 'foo'}