Skip to content

Commit

Permalink
feat!: raise TypeError if config field does not match declared type
Browse files Browse the repository at this point in the history
This makes parsing stricter and could result in errors in some existing
configs. However, it allows for more precise deserialization, especially in the
case of union types.
  • Loading branch information
eginhard committed Jan 10, 2025
1 parent b41b0e1 commit ffe6719
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 23 deletions.
36 changes: 27 additions & 9 deletions coqpit/coqpit.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]:
Returns:
Dict: deserialized dictionary.
"""
if not isinstance(x, dict):
msg = f"Value `{x}` is not a dictionary"
raise TypeError(msg)
out_dict: dict[Any, Any] = {}
for k, v in x.items():
if v is None: # if {'key':None}
Expand All @@ -204,6 +207,9 @@ def _deserialize_list(x: list[Any], field_type: FieldType) -> list[Any]:
Returns:
[List]: deserialized list.
"""
if not isinstance(x, list):
msg = f"Value `{x}` does not match field type `{field_type}`"
raise TypeError(msg)
field_args = typing.get_args(field_type)
if len(field_args) == 0:
return x
Expand Down Expand Up @@ -232,7 +238,7 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any:
try:
x = _deserialize(x, arg)
break
except ValueError:
except (TypeError, ValueError):
pass
return x

Expand All @@ -252,18 +258,30 @@ def _deserialize_primitive_types(
Returns:
Union[int, float, str, bool]: deserialized value.
"""
if isinstance(x, str | bool):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)

type_mismatch = f"Value `{x}` does not match field type `{field_type}`"
if x is None and type(None) in typing.get_args(field_type):
return None
if isinstance(x, str):
if base_type is not str:
raise TypeError(type_mismatch)
return x
if isinstance(x, bool):
if base_type is not bool:
raise TypeError(type_mismatch)
return x
if isinstance(x, int | float):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)
if x == float("inf") or x == float("-inf"):
# if value type is inf return regardless.
return x
if base_type is not int and base_type is not float:
raise TypeError(type_mismatch)
return base_type(x)
return None
raise TypeError(type_mismatch)


def _deserialize_path(x: Any, field_type: FieldType) -> Path | None:
Expand Down Expand Up @@ -299,8 +317,8 @@ def _deserialize(x: Any, field_type: FieldType) -> Any:
return _deserialize_path(x, field_type)
if _is_primitive_type(_drop_none_type(field_type)):
return _deserialize_primitive_types(x, field_type)
msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type."
raise ValueError(msg)
msg = f"Type '{type(x)}' of value '{x}' does not match declared '{field_type}' field type."
raise TypeError(msg)


CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"]
Expand Down
80 changes: 66 additions & 14 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from pathlib import Path
from types import UnionType
from typing import Any

import pytest

from coqpit.coqpit import Coqpit, _deserialize_list, _deserialize_primitive_types
from coqpit.coqpit import Coqpit, FieldType, _deserialize_list, _deserialize_primitive_types, _deserialize_union


@dataclass
Expand Down Expand Up @@ -65,36 +67,86 @@ def test_serialization() -> None:
def test_deserialize_list() -> None:
assert _deserialize_list([1, 2, 3], list) == [1, 2, 3]
assert _deserialize_list([1, 2, 3], list[int]) == [1, 2, 3]
assert _deserialize_list([[1, 2, 3]], list[list[int]]) == [[1, 2, 3]]
assert _deserialize_list([1.0, 2.0, 3.0], list[float]) == [1.0, 2.0, 3.0]
assert _deserialize_list([1, 2, 3], list[float]) == [1.0, 2.0, 3.0]
assert _deserialize_list([1, 2, 3], list[str]) == ["1", "2", "3"]
assert _deserialize_list(["1", "2", "3"], list[str]) == ["1", "2", "3"]

with pytest.raises(TypeError, match="does not match field type"):
_deserialize_list([1, 2, 3], list[list[int]])

def test_deserialize_primitive_type() -> None:
cases = (

@pytest.mark.parametrize(
("value", "field_type", "expected"),
[
(True, bool, True),
(False, bool, False),
("a", str, "a"),
("3", str, "3"),
(3, int, 3),
(3, float, 3.0),
(3, str, "3"),
(3.0, str, "3.0"),
(3, bool, True),
("a", str | None, "a"),
("3", str | None, "3"),
(3, int | None, 3),
(3, float | None, 3.0),
(None, str | None, None),
(None, int | None, None),
(None, float | None, None),
(None, str | None, None),
(None, bool | None, None),
(float("inf"), float, float("inf")),
(float("inf"), int, float("inf")),
(float("-inf"), float, float("-inf")),
(float("-inf"), int, float("-inf")),
)
for value, field_type, expected in cases:
assert _deserialize_primitive_types(value, field_type) == expected

with pytest.raises(TypeError):
_deserialize_primitive_types(3, Coqpit)
],
)
def test_deserialize_primitive_type(
value: str | bool | float | None,
field_type: FieldType,
expected: str | bool | float | None,
) -> None:
assert _deserialize_primitive_types(value, field_type) == expected


@pytest.mark.parametrize(
    ("value", "field_type"),
    [
        (3, str),
        (3, str | None),
        (3.0, str),
        (3, bool),
        ("1", int),
        ("2.0", float),
        ("True", bool),
        ("True", bool | None),
        ("", bool | None),
        ([1, 2], str),
        ([1, 2, 3], int),
    ],
)
def test_deserialize_primitive_type_mismatch(
    value: Any,
    field_type: FieldType,
) -> None:
    """Check that a value of the wrong type is rejected with a ``TypeError``.

    The cases cover numbers given for string/bool fields, strings given for
    numeric/bool fields, and lists given for primitive fields. ``value`` is
    annotated as ``Any`` because the parametrization deliberately includes
    non-primitive inputs such as lists.

    Args:
        value: Input that does not match ``field_type``.
        field_type: Declared primitive (optionally ``| None``) field type.
    """
    with pytest.raises(TypeError, match="does not match field type"):
        _deserialize_primitive_types(value, field_type)


@pytest.mark.parametrize(
    ("value", "field_type", "expected"),
    [
        ("a", int | str, "a"),
        ("a", str | int, "a"),
        (1, int | str, 1),
        (1, str | int, 1),
        (1, str | int | list[int], 1),
        ([1, 2], str | int | list[int], [1, 2]),
        ([1, 2], list[int] | int | str, [1, 2]),
        ([1, 2], dict | list, [1, 2]),
        (["a", "b"], list[str] | list[list[str]], ["a", "b"]),
        (["a", "b"], list[list[str]] | list[str], ["a", "b"]),
        ([["a", "b"]], list[str] | list[list[str]], [["a", "b"]]),
        ([["a", "b"]], list[list[str]] | list[str], [["a", "b"]]),
    ],
)
def test_deserialize_union(value: Any, field_type: UnionType, expected: Any) -> None:
    """Check that union deserialization returns the expected value.

    Every case is listed with the union members in more than one order, and
    in every case the expected result equals the input value.

    Args:
        value: Input to deserialize.
        field_type: Union type the value is declared as.
        expected: Value the deserializer should return.
    """
    result = _deserialize_union(value, field_type)
    assert result == expected

0 comments on commit ffe6719

Please sign in to comment.