diff --git a/src/cattr/converters.py b/src/cattr/converters.py index aa08ab35..c7a1c60e 100644 --- a/src/cattr/converters.py +++ b/src/cattr/converters.py @@ -405,9 +405,13 @@ def _structure_call(obj, cl): @staticmethod def _structure_literal(val, type): - if val not in type.__args__: - raise Exception(f"{val} not in literal {type}") - return val + vals = { + (x.value if isinstance(x, Enum) else x): x for x in type.__args__ + } + try: + return vals[val] + except KeyError: + raise Exception(f"{val} not in literal {type}") from None # Attrs classes. diff --git a/tests/test_structure_attrs.py b/tests/test_structure_attrs.py index dbd5a6c1..712de2ce 100644 --- a/tests/test_structure_attrs.py +++ b/tests/test_structure_attrs.py @@ -1,4 +1,5 @@ """Loading of attrs classes.""" +from enum import Enum from ipaddress import IPv4Address, IPv6Address, ip_address from typing import Union from unittest.mock import Mock @@ -164,6 +165,27 @@ class ClassWithLiteral: ) == ClassWithLiteral(4) +@pytest.mark.skipif(is_py37, reason="Not supported on 3.7") +@pytest.mark.parametrize("converter_cls", [Converter, GenConverter]) +def test_structure_literal_enum(converter_cls): + """Structuring a class with a literal field works.""" + from typing import Literal + + converter = converter_cls() + + class Foo(Enum): + FOO = 1 + BAR = 2 + + @define + class ClassWithLiteral: + literal_field: Literal[Foo.FOO] = Foo.FOO + + assert converter.structure( + {"literal_field": 1}, ClassWithLiteral + ) == ClassWithLiteral(Foo.FOO) + + @pytest.mark.skipif(is_py37, reason="Not supported on 3.7") @pytest.mark.parametrize("converter_cls", [Converter, GenConverter]) def test_structure_literal_multiple(converter_cls): @@ -172,9 +194,17 @@ def test_structure_literal_multiple(converter_cls): converter = converter_cls() + class Foo(Enum): + FOO = 7 + FOOFOO = 77 + + class Bar(int, Enum): + BAR = 8 + BARBAR = 88 + @define class ClassWithLiteral: - literal_field: Literal[4, 5] = 4 + literal_field: Literal[4, 5, Foo.FOO, Bar.BARBAR] = 4 assert converter.structure( {"literal_field": 4}, ClassWithLiteral @@ -183,6 +213,18 @@ class ClassWithLiteral: {"literal_field": 5}, ClassWithLiteral ) == ClassWithLiteral(5) + cwl = converter.structure( + {"literal_field": 7}, ClassWithLiteral + ) + assert cwl ==ClassWithLiteral(Foo.FOO) + assert isinstance(cwl.literal_field, Foo) + + cwl = converter.structure( + {"literal_field": 88}, ClassWithLiteral + ) + assert cwl ==ClassWithLiteral(Bar.BARBAR) + assert isinstance(cwl.literal_field, Bar) + @pytest.mark.skipif(is_py37, reason="Not supported on 3.7") @pytest.mark.parametrize("converter_cls", [Converter, GenConverter])