Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: value_unit_pair module #95

Merged
merged 2 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions python/ngen_conf/src/ngen/config/init_config/value_unit_pair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import re
from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, Union

from pydantic import validator
from pydantic.generics import GenericModel
from typing_extensions import Self, override

if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny

V = TypeVar("V")
"""The type of `ValueUnitPair`'s `value` field."""
U = TypeVar("U")
"""The type of `ValueUnitPair`'s `unit` field."""


# Splits the serialized form `<value>[<unit>]` (e.g. `2[m]`) into two capture
# groups: the text before the brackets and the text inside them.
# Intentionally permissive: both groups are greedy `.*`, so either may be empty
# and neither restricts which characters may appear.
_VAL_UNIT_RE = re.compile(r"(.*)\[(.*)\]")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that I think we should now, but we may want to consider a more restrictive regex here in the future... especially considering captures which have spaces, non-standard characters, etc. This may be impossible to generalize and too far out of scope though -- just a passing thought.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree. I didn't and don't know at this point in time what appropriate constraints to enforce here, so I left it a little open. My thinking was we could always refine this if needed.



class ValueUnitPair(GenericModel, Generic[V, U]):
    """A value paired with a unit, serialized as the string ``<value>[<unit>]``.

    `V` is the type of the `value` field and `U` the type of the `unit` field.
    Instances can be built from another `ValueUnitPair`, from a mapping with
    `value` and `unit` keys, or from the serialized string form (e.g. `2[m]`).
    """

    value: V
    unit: U

    @override
    @classmethod
    def validate(cls: Type[Self], value: Any) -> Self:
        """Coerce `value` into an instance of this (possibly parametrized) model.

        Accepted inputs:
          * a `ValueUnitPair` instance: its fields are re-validated against
            this class's type parameters and a new instance is returned;
          * a `dict` containing both a `"value"` and a `"unit"` key;
          * a `str` matching `_VAL_UNIT_RE`, e.g. `"2[m]"`.

        Raises:
            ValueError: if `value` is a dict missing either key, is of any
                other unsupported type, or is a str without a `[...]` part.
                Errors raised by the constructor while validating the fields
                propagate unchanged.
        """
        if isinstance(value, ValueUnitPair):
            # return a shallow copy. this also validates / coerces mismatching generic types
            return cls(value=value.value, unit=value.unit)

        # unpack kwargs like arguments into expected string form
        if isinstance(value, dict):
            v = value.get("value", Ellipsis)
            u = value.get("unit", Ellipsis)
            # Identity check against the sentinel: `==` would call the user
            # value's `__eq__`, which may return a non-bool or raise.
            if v is Ellipsis or u is Ellipsis:
                raise ValueError(f"cannot coerce value='{value!r}' into {cls.__name__}")
            return cls(value=v, unit=u)

        # cannot further coerce / validate value
        if not isinstance(value, str):
            raise ValueError(f"cannot coerce value='{value!r}' into {cls.__name__}")

        match = _VAL_UNIT_RE.search(value)

        if match is None:
            raise ValueError(f"no match in str: {value!r}")

        # examples
        # 2[m] -> ("2", "m")
        # 1,2,3,4[m/m] -> ("1,2,3,4", "m/m")
        raw_value, raw_unit = match.groups()
        return cls(value=raw_value, unit=raw_unit)

    @override
    @classmethod
    def parse_obj(cls: Type[Self], obj: Any) -> Self:
        """Parse `obj` using the same coercion rules as `validate`."""
        return cls.validate(obj)

    def _serialize(self) -> str:
        """Return the string form ``<value>[<unit>]`` built from `str()` of each field."""
        return f"{str(self.value)}[{str(self.unit)}]"

    @override
    def dict(
        self,
        *,
        include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
        exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
        by_alias: bool = False,
        skip_defaults: Optional[bool] = None,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> "DictStrAny":
        """Return the serialized string form instead of a field mapping.

        All keyword arguments are accepted for signature compatibility and
        ignored. Despite the inherited return annotation, the result is a `str`.
        NOTE(review): presumably this makes enclosing models emit the compact
        string when they are dumped -- confirm against callers.
        """
        return self._serialize()


# NOTE: the type constraint could be relaxed in the future, but doing so would
# need code modification (presumably in how `ListUnitPair` splits and joins its
# items on "," -- confirm before relaxing).
T = TypeVar("T", str, int, float, bool)
"""The type of `ListUnitPair`'s list items. Constrained to non-nullable json primitive types."""


class ListUnitPair(ValueUnitPair[List[T], U], Generic[T, U]):
    """A `ValueUnitPair` whose value is a list, serialized as ``a,b,c[<unit>]``."""

    @validator("value", pre=True)
    def _coerce_values(cls, value: Union[str, List[str]]) -> List[str]:
        """Turn a comma separated string into a list before field validation.

        Lists pass through untouched, the empty string becomes an empty list,
        and any other string is split on ",". Anything else is rejected with a
        `ValueError`.
        """
        if isinstance(value, list):
            return value
        if isinstance(value, str):
            if not value:
                return []
            return value.split(",")
        raise ValueError(f"cannot coerce value='{value!r}' into list")

    @override
    def _serialize(self) -> str:
        """Join the `str()` of every item with "," and append ``[<unit>]``."""
        joined = ",".join(str(item) for item in self.value)
        return f"{joined}[{self.unit!s}]"
124 changes: 124 additions & 0 deletions python/ngen_conf/tests/test_value_unit_pair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import typing

import pydantic
import pytest
from ngen.config.init_config.value_unit_pair import ListUnitPair, ValueUnitPair


@pytest.mark.parametrize(
    "ty, value, unit",
    [
        (int, 42, "m"),
        (float, 42.0, "m"),
        (str, "42", "m"),
        (bool, True, "m"),
    ],
)
def test_value_unit_pair_initialization(ty: type, value, unit: str):
    """Fields keep both the value and the exact type they were constructed with."""
    pair_type = ValueUnitPair[ty, typing.Literal[unit]]
    pair = pair_type(value=value, unit=unit)
    assert pair.value == value
    assert pair.unit == unit
    # exact type identity on purpose: `bool` must not be accepted as `int`
    assert type(pair.value) is type(value)
    assert type(pair.unit) is type(unit)


@pytest.mark.parametrize(
    "ty, value, unit",
    [
        (int, "not a number", "m"),
        (float, 2j, "m"),
        (str, 2j, "m"),
        (bool, 2j, "m"),
    ],
)
def test_value_unit_pair_initialization_negative(ty: type, value, unit: str):
    """Values that cannot be coerced to the declared type are rejected."""
    pair_type = ValueUnitPair[ty, typing.Literal[unit]]
    with pytest.raises(pydantic.ValidationError):
        pair_type(value=value, unit=unit)


@pytest.mark.parametrize(
    "ty, value, unit, expected",
    [
        (int, 42, "m", "42[m]"),
        (float, 42.0, "m", "42.0[m]"),
        (str, "42", "m", "42[m]"),
        (bool, True, "m", "True[m]"),
    ],
)
def test_value_unit_pair_serialize(ty: type, value, unit: str, expected: str):
    """`dict()` yields the `<value>[<unit>]` string form."""
    pair_type = ValueUnitPair[ty, typing.Literal[unit]]
    pair = pair_type(value=value, unit=unit)
    serialized = pair.dict()
    assert serialized == expected


@pytest.mark.parametrize(
    "serial, expected",
    [
        ("42[m]", ValueUnitPair[int, typing.Literal["m"]](value=42, unit="m")),
        ("42.0[m]", ValueUnitPair[float, typing.Literal["m"]](value=42.0, unit="m")),
        ("42[m]", ValueUnitPair[str, typing.Literal["m"]](value="42", unit="m")),
        ("True[m]", ValueUnitPair[bool, typing.Literal["m"]](value=True, unit="m")),
    ],
)
def test_value_unit_pair_from_str(serial: str, expected: ValueUnitPair):
    """Parsing the serialized string reproduces the expected model."""
    parsed = expected.parse_obj(serial)
    assert parsed == expected


@pytest.mark.parametrize(
    "ty, value, unit",
    [
        (int, [1, 2, 3], "m"),
        (float, [1.0, 2.0, 3.0], "m"),
        (str, ["a", "b", "c"], "m"),
        (bool, [True, False, True], "m"),
    ],
)
def test_list_unit_pair_initialization(ty: type, value, unit: str):
    """List values and the unit survive construction with their types intact."""
    pair_type = ListUnitPair[ty, typing.Literal[unit]]
    pair = pair_type(value=value, unit=unit)
    assert pair.value == value
    assert pair.unit == unit
    assert type(pair.value) is type(value)
    assert type(pair.unit) is type(unit)


@pytest.mark.parametrize(
    "ty, value, unit, expected",
    [
        (int, [], "m", "[m]"),
        (int, [1, 2, 3], "m", "1,2,3[m]"),
        (int, "", "m", "[m]"),
        (int, "1,2,3", "m", "1,2,3[m]"),
        (int, " 1,2,3", "m", "1,2,3[m]"),
        (int, "1,2,3 ", "m", "1,2,3[m]"),
        (int, " 1,2,3 ", "m", "1,2,3[m]"),
    ],
)
def test_list_unit_pair_serialize(ty: type, value, unit: str, expected: str):
    """Lists and comma separated strings both serialize to `a,b,c[<unit>]`."""
    pair_type = ListUnitPair[ty, typing.Literal[unit]]
    pair = pair_type(value=value, unit=unit)
    serialized = pair.dict()
    assert serialized == expected


def test_generic_bounds_are_upheld():
    """Re-parsing a str-valued pair as an int-valued pair coerces the value."""
    meters = typing.Literal["m"]
    str_pair = ValueUnitPair[str, meters](value="42", unit="m")
    assert isinstance(str_pair.value, str)

    int_pair = ValueUnitPair[int, meters].parse_obj(str_pair)
    assert isinstance(int_pair.value, int)


def test_generic_bounds_are_upheld_when_composed():
    """A pair nested in another model is coerced to the field's declared types."""

    class Outer(pydantic.BaseModel):
        inner: ValueUnitPair[int, typing.Literal["m"]]

    str_pair = ValueUnitPair[str, typing.Literal["m"]](value="42", unit="m")
    outer = Outer(inner=str_pair)
    assert isinstance(outer.inner.value, int)


def test_generic_bounds_are_upheld_negative():
    """Re-parsing fails when the value cannot be coerced to the new type."""
    meters = typing.Literal["m"]
    str_pair = ValueUnitPair[str, meters](value="not coercible to int", unit="m")
    assert isinstance(str_pair.value, str)

    int_pair_type = ValueUnitPair[int, meters]
    with pytest.raises(pydantic.ValidationError):
        int_pair_type.parse_obj(str_pair)
Loading