Skip to content

Commit

Permalink
More union serialization tidying (#1536)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Nov 11, 2024
1 parent cd270e4 commit cd0346d
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ install:
pip install -U pip wheel pre-commit
pip install -r tests/requirements.txt
pip install -r tests/requirements-linting.txt
pip install -e .
pip install -v -e .
pre-commit install

.PHONY: install-rust-coverage
Expand Down
31 changes: 22 additions & 9 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,19 +422,32 @@ impl TaggedUnionSerializer {
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
let py = value.py();
let discriminator_value = match &self.discriminator {
Discriminator::LookupKey(lookup_key) => lookup_key
.simple_py_get_attr(value)
.ok()
.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))),
Discriminator::LookupKey(lookup_key) => {
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could
// be doing a discriminator lookup on a typed dict, and there's no good way to check that
// at this point. we could be more strict and only do this in lax mode...
let getattr_result = match value.is_instance_of::<PyDict>() {
true => {
let value_dict = value.downcast::<PyDict>().unwrap();
lookup_key.py_get_dict_item(value_dict).ok()
}
false => lookup_key.simple_py_get_attr(value).ok(),
};
getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py)))
}
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
};
if discriminator_value.is_none() {
let value_str = truncate_safe_repr(value, None);
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
if extra.check == SerCheck::None {
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);
}
}
discriminator_value
}
Expand Down
62 changes: 62 additions & 0 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,43 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
assert m in str(w[0].message)


def test_mixed_union_models_and_other_types() -> None:
s = SchemaSerializer(
core_schema.union_schema(
[
core_schema.tagged_union_schema(
discriminator='type_',
choices={
'cat': core_schema.model_schema(
cls=ModelCat,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
},
),
),
'dog': core_schema.model_schema(
cls=ModelDog,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
},
),
),
},
),
core_schema.str_schema(),
]
)
)

assert s.to_python(ModelCat(type_='cat'), warnings='error') == {'type_': 'cat'}
assert s.to_python(ModelDog(type_='dog'), warnings='error') == {'type_': 'dog'}
# note, this fails as ModelCat and ModelDog (discriminator warnings, etc), but the warnings
# don't bubble up to this level :)
assert s.to_python('a string', warnings='error') == 'a string'


@pytest.mark.parametrize(
'input,expected',
[
Expand Down Expand Up @@ -1000,3 +1037,28 @@ def test_union_of_unions_of_models_with_tagged_union_json_serialization(
)

assert s.to_json(input, warnings='error') == expected


def test_discriminated_union_ser_with_typed_dict() -> None:
v = SchemaSerializer(
core_schema.tagged_union_schema(
{
'a': core_schema.typed_dict_schema(
{
'type': core_schema.typed_dict_field(core_schema.literal_schema(['a'])),
'a': core_schema.typed_dict_field(core_schema.int_schema()),
}
),
'b': core_schema.typed_dict_schema(
{
'type': core_schema.typed_dict_field(core_schema.literal_schema(['b'])),
'b': core_schema.typed_dict_field(core_schema.str_schema()),
}
),
},
discriminator='type',
)
)

assert v.to_python({'type': 'a', 'a': 1}, warnings='error') == {'type': 'a', 'a': 1}
assert v.to_python({'type': 'b', 'b': 'foo'}, warnings='error') == {'type': 'b', 'b': 'foo'}

0 comments on commit cd0346d

Please sign in to comment.