Skip to content

Commit

Permalink
Merge pull request #3 from seandstewart/seandstewart/optional-types
Browse files Browse the repository at this point in the history
fix: correct handling optional types
  • Loading branch information
seandstewart authored Oct 16, 2024
2 parents cba3aa8 + 79e431a commit 6bd7023
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 15 deletions.
12 changes: 9 additions & 3 deletions src/typelib/marshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class UnionMarshaller(AbstractMarshaller[UnionT], tp.Generic[UnionT]):
- [`UnionUnmarshaller`][typelib.unmarshals.routines.UnionUnmarshaller]
"""

__slots__ = ("stack", "ordered_routines")
__slots__ = ("stack", "ordered_routines", "nullable")

def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None):
"""Constructor.
Expand All @@ -274,19 +274,25 @@ def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None
"""
super().__init__(t, context, var=var)
self.stack = inspection.args(t)
self.nullable = inspection.isoptionaltype(t)
self.ordered_routines = [self.context[typ] for typ in self.stack]

def __call__(self, val: UnionT) -> serdes.MarshalledValueT:
"""Unmarshal a value into the bound `UnionT`.
"""Marshal a value into the bound `UnionT`.
Args:
val: The input value to unmarshal.
Raises:
ValueError: If `val` cannot be marshalled via any member type.
"""
if self.nullable and val is None:
return val

for routine in self.ordered_routines:
with contextlib.suppress(ValueError, TypeError, SyntaxError):
with contextlib.suppress(
ValueError, TypeError, SyntaxError, AttributeError
):
unmarshalled = routine(val)
return unmarshalled

Expand Down
7 changes: 6 additions & 1 deletion src/typelib/unmarshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,9 @@ def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None
"""
super().__init__(t, context, var=var)
self.stack = inspection.args(t)
if inspection.isoptionaltype(t):
self.stack = (self.stack[-1], *self.stack[:-1])

self.ordered_routines = [self.context[typ] for typ in self.stack]

def __call__(self, val: tp.Any) -> UnionT:
Expand All @@ -690,7 +693,9 @@ def __call__(self, val: tp.Any) -> UnionT:
ValueError: If `val` cannot be unmarshalled into any member type.
"""
for routine in self.ordered_routines:
with contextlib.suppress(ValueError, TypeError, SyntaxError):
with contextlib.suppress(
ValueError, TypeError, SyntaxError, AttributeError
):
unmarshalled = routine(val)
return unmarshalled

Expand Down
5 changes: 5 additions & 0 deletions tests/unit/marshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
given_input=2,
expected_output=2,
),
optional_none=dict(
given_type=typing.Optional[typing.Union[int, str]],
given_input=None,
expected_output=None,
),
datetime=dict(
given_type=datetime.datetime,
given_input=datetime.datetime.fromtimestamp(0, datetime.timezone.utc),
Expand Down
40 changes: 29 additions & 11 deletions tests/unit/marshals/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_date_marshaller(given_input, expected_output):
expected_output=datetime.datetime(1969, 12, 31).isoformat(),
),
)
def test_datetime_unmarshaller(given_input, expected_output):
def test_datetime_marshaller(given_input, expected_output):
# Given
given_marshaller = routines.DateTimeMarshaller(datetime.datetime, {})
# When
Expand All @@ -141,7 +141,7 @@ def test_datetime_unmarshaller(given_input, expected_output):
expected_output="00:00:00+00:00",
),
)
def test_time_unmarshaller(given_input, expected_output):
def test_time_marshaller(given_input, expected_output):
# Given
given_marshaller = routines.TimeMarshaller(datetime.time, {})
# When
Expand All @@ -153,7 +153,7 @@ def test_time_unmarshaller(given_input, expected_output):
@pytest.mark.suite(
timedelta=dict(given_input=datetime.timedelta(seconds=1), expected_output="PT1S"),
)
def test_timedelta_unmarshaller(given_input, expected_output):
def test_timedelta_marshaller(given_input, expected_output):
# Given
given_marshaller = routines.TimeDeltaMarshaller(datetime.timedelta, {})
# When
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_mapping_marshaller(given_input, expected_output):
expected_output=["field", "value"],
),
)
def test_iterable_unmarshaller(given_input, expected_output):
def test_iterable_marshaller(given_input, expected_output):
# Given
given_marshaller = routines.IterableMarshaller(typing.Iterable, {})
# When
Expand Down Expand Up @@ -259,8 +259,26 @@ def test_literal_marshaller(given_input, given_literal, given_context, expected_
},
expected_output=1,
),
optional_date_none=dict(
given_input=None,
given_union=typing.Optional[datetime.date],
given_context={
datetime.date: routines.DateMarshaller(datetime.date, {}),
type(None): routines.NoOpMarshaller(type(None), {}),
},
expected_output=None,
),
optional_date_date=dict(
given_input=datetime.date.today(),
given_union=typing.Optional[datetime.date],
given_context={
datetime.date: routines.DateMarshaller(datetime.date, {}),
type(None): routines.NoOpMarshaller(type(None), {}),
},
expected_output=datetime.date.today().isoformat(),
),
)
def test_union_unmarshaller(given_input, given_union, given_context, expected_output):
def test_union_marshaller(given_input, given_union, given_context, expected_output):
# Given
given_marshaller = routines.UnionMarshaller(given_union, given_context)
# When
Expand All @@ -280,7 +298,7 @@ def test_union_unmarshaller(given_input, given_union, given_context, expected_ou
expected_output={"field": 1},
),
)
def test_subscripted_mapping_unmarshaller(
def test_subscripted_mapping_marshaller(
given_input, given_mapping, given_context, expected_output
):
# Given
Expand Down Expand Up @@ -373,7 +391,7 @@ def test_subscripted_iterable_marshaller(
expected_output=["field", 1],
),
)
def test_fixed_tuple_unmarshaller(
def test_fixed_tuple_marshaller(
given_input, given_tuple, given_context, expected_output
):
# Given
Expand Down Expand Up @@ -419,7 +437,7 @@ def test_fixed_tuple_unmarshaller(
given_input=models.TDict(field="data", value=1),
),
)
def test_structured_type_unmarshaller(
def test_structured_type_marshaller(
given_input, given_cls, given_context, expected_output
):
# Given
Expand Down Expand Up @@ -456,12 +474,12 @@ def test_invalid_union():
given_marshaller(given_value)


def test_enum_unmarshaller():
def test_enum_marshaller():
# Given
given_unmarshaller = routines.EnumMarshaller(models.GivenEnum, {})
given_marshaller = routines.EnumMarshaller(models.GivenEnum, {})
given_value = models.GivenEnum.one
expected_value = models.GivenEnum.one.value
# When
unmarshalled = given_unmarshaller(given_value)
unmarshalled = given_marshaller(given_value)
# Then
assert unmarshalled == expected_value
5 changes: 5 additions & 0 deletions tests/unit/unmarshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@
timestamp=datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
),
),
optional_none=dict(
given_type=typing.Optional[typing.Union[int, str]],
given_input=None,
expected_output=None,
),
attrib_conflict=dict(
given_type=models.Parent,
given_input={"intersection": {"a": 0}, "child": {"intersection": {"b": 0}}},
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/unmarshals/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,24 @@ def test_literal_unmarshaller(
},
expected_output=1,
),
optional_date_none=dict(
given_input=None,
given_union=typing.Optional[datetime.date],
given_context={
datetime.date: routines.DateUnmarshaller(datetime.date, {}),
type(None): routines.NoOpUnmarshaller(type(None), {}),
},
expected_output=None,
),
optional_date_date=dict(
given_input=datetime.date.today().isoformat(),
given_union=typing.Optional[datetime.date],
given_context={
datetime.date: routines.DateUnmarshaller(datetime.date, {}),
type(None): routines.NoneTypeUnmarshaller(type(None), {}),
},
expected_output=datetime.date.today(),
),
)
def test_union_unmarshaller(given_input, given_union, given_context, expected_output):
# Given
Expand Down

0 comments on commit 6bd7023

Please sign in to comment.