Skip to content

Commit 598344f

Browse files
committed
fix: handle constrained datetimes
1 parent 6ab574a commit 598344f

File tree

4 files changed

+142
-11
lines changed

4 files changed

+142
-11
lines changed

polyfactory/factories/base.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
handle_constrained_collection,
7474
handle_constrained_mapping,
7575
)
76-
from polyfactory.value_generators.constrained_dates import handle_constrained_date
76+
from polyfactory.value_generators.constrained_dates import handle_constrained_date, handle_constrained_datetime
7777
from polyfactory.value_generators.constrained_numbers import (
7878
handle_constrained_decimal,
7979
handle_constrained_float,
@@ -624,15 +624,15 @@ def create_factory(
624624
)
625625

626626
@classmethod
627-
def get_constrained_field_value( # noqa: C901, PLR0911
627+
def get_constrained_field_value( # noqa: C901, PLR0911, PLR0912
628628
cls,
629629
annotation: Any,
630630
field_meta: FieldMeta,
631631
field_build_parameters: Any | None = None,
632632
build_context: BuildContext | None = None,
633633
) -> Any:
634+
constraints = cast("Constraints", field_meta.constraints)
634635
try:
635-
constraints = cast("Constraints", field_meta.constraints)
636636
if is_safe_subclass(annotation, float):
637637
return handle_constrained_float(
638638
random=cls.__random__,
@@ -705,6 +705,16 @@ def get_constrained_field_value( # noqa: C901, PLR0911
705705
build_context=build_context,
706706
)
707707

708+
if is_safe_subclass(annotation, datetime):
709+
return handle_constrained_datetime(
710+
faker=cls.__faker__,
711+
ge=cast("Any", constraints.get("ge")),
712+
gt=cast("Any", constraints.get("gt")),
713+
le=cast("Any", constraints.get("le")),
714+
lt=cast("Any", constraints.get("lt")),
715+
tz=cast("Any", constraints.get("tz")),
716+
)
717+
708718
if is_safe_subclass(annotation, date):
709719
return handle_constrained_date(
710720
faker=cls.__faker__,

polyfactory/value_generators/constrained_dates.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,49 @@ def handle_constrained_date(
2727
:returns: A date instance.
2828
"""
2929
start_date = datetime.now(tz=tz).date() - timedelta(days=100)
30-
if ge:
30+
if ge is not None:
3131
start_date = ge
32-
elif gt:
32+
elif gt is not None:
3333
start_date = gt + timedelta(days=1)
3434

3535
end_date = datetime.now(tz=timezone.utc).date() + timedelta(days=100)
36-
if le:
36+
if le is not None:
3737
end_date = le
38-
elif lt:
38+
elif lt is not None:
3939
end_date = lt - timedelta(days=1)
4040

4141
return faker.date_between(start_date=start_date, end_date=end_date)
42+
43+
44+
def handle_constrained_datetime(
45+
faker: Faker,
46+
ge: datetime | None = None,
47+
gt: datetime | None = None,
48+
le: datetime | None = None,
49+
lt: datetime | None = None,
50+
tz: tzinfo | None = None,
51+
) -> datetime:
52+
"""Generates a datetime value fulfilling the expected constraints.
53+
54+
:param faker: An instance of faker.
55+
:param lt: Less than value.
56+
:param le: Less than or equal value.
57+
:param gt: Greater than value.
58+
:param ge: Greater than or equal value.
59+
:param tz: A timezone. If not provided, infers from constraint values.
60+
61+
:returns: A datetime instance.
62+
"""
63+
start_datetime = datetime.now(tz=tz) - timedelta(days=100)
64+
if ge:
65+
start_datetime = ge
66+
elif gt:
67+
start_datetime = gt + timedelta(seconds=1)
68+
69+
end_datetime = datetime.now(tz=tz) + timedelta(days=100)
70+
if le is not None:
71+
end_datetime = le
72+
elif lt is not None:
73+
end_datetime = lt - timedelta(seconds=1)
74+
75+
return faker.date_time_between(start_date=start_datetime, end_date=end_datetime, tzinfo=tz)

tests/constraints/test_date_constraints.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from datetime import date, timedelta
2-
from typing import Optional
32

43
import pytest
54
from hypothesis import given
@@ -14,10 +13,18 @@
1413
dates(max_value=date.today() - timedelta(days=3)),
1514
dates(min_value=date.today()),
1615
)
17-
@pytest.mark.parametrize(("start", "end"), (("ge", "le"), ("gt", "lt"), ("ge", "lt"), ("gt", "le")))
16+
@pytest.mark.parametrize(
17+
("start", "end"),
18+
(
19+
("ge", "le"),
20+
("gt", "lt"),
21+
("ge", "lt"),
22+
("gt", "le"),
23+
),
24+
)
1825
def test_handle_constrained_date(
19-
start: Optional[str],
20-
end: Optional[str],
26+
start: str | None,
27+
end: str | None,
2128
start_date: date,
2229
end_date: date,
2330
) -> None:
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Test datetime constraints, including Issue #734."""
2+
3+
from datetime import datetime, timedelta, timezone
4+
from typing import Annotated
5+
6+
import pytest
7+
from annotated_types import Timezone
8+
from hypothesis import given
9+
from hypothesis.strategies import datetimes
10+
from typing_extensions import Literal
11+
12+
from pydantic import BaseModel, BeforeValidator, Field
13+
14+
from polyfactory.factories.pydantic_factory import ModelFactory
15+
16+
17+
@given(
18+
datetimes(min_value=datetime(1900, 1, 1), max_value=datetime.now() - timedelta(days=3)),
19+
datetimes(min_value=datetime.now(), max_value=datetime(2100, 1, 1)),
20+
)
21+
@pytest.mark.parametrize(
22+
("start", "end"),
23+
(
24+
("ge", "le"),
25+
("gt", "lt"),
26+
("ge", "lt"),
27+
("gt", "le"),
28+
),
29+
)
30+
def test_handle_constrained_datetime(
31+
start: Literal["ge", "gt"],
32+
end: Literal["le", "lt"],
33+
start_datetime: datetime,
34+
end_datetime: datetime,
35+
) -> None:
36+
"""Test that constrained datetimes are generated correctly."""
37+
if start_datetime == end_datetime:
38+
return
39+
40+
kwargs: dict[Literal["ge", "gt", "le", "lt"], datetime] = {}
41+
if start:
42+
kwargs[start] = start_datetime
43+
if end:
44+
kwargs[end] = end_datetime
45+
46+
class MyModel(BaseModel):
47+
value: datetime = Field(**kwargs) # type: ignore
48+
49+
class MyFactory(ModelFactory[MyModel]): ...
50+
51+
result = MyFactory.build()
52+
53+
assert result.value
54+
assert isinstance(result.value, datetime), "Should be datetime.datetime, not date"
55+
assert result.value >= start_datetime if "ge" in kwargs else result.value > start_datetime
56+
assert result.value <= end_datetime if "le" in kwargs else result.value < end_datetime
57+
58+
59+
def validate_datetime(value: datetime) -> datetime:
60+
"""Validator that expects a datetime object with timezone info."""
61+
assert isinstance(value, datetime), f"Expected datetime.datetime, got {type(value)}"
62+
assert value.tzinfo == timezone.utc, f"Expected UTC timezone, got {value.tzinfo}"
63+
return value
64+
65+
66+
ValidatedDatetime = Annotated[datetime, BeforeValidator(validate_datetime), Timezone(tz=timezone.utc)]
67+
68+
69+
def test_annotated_datetime_with_validator_and_constraint() -> None:
70+
minimum_datetime = datetime(2030, 1, 1, tzinfo=timezone.utc)
71+
72+
class MyModel(BaseModel):
73+
dt: ValidatedDatetime = Field(gt=minimum_datetime)
74+
75+
class MyModelFactory(ModelFactory[MyModel]): ...
76+
77+
instance = MyModelFactory.build()
78+
assert isinstance(instance.dt, datetime), "Should be datetime.datetime"
79+
assert instance.dt.tzinfo == timezone.utc, "Should have UTC timezone"
80+
assert instance.dt > minimum_datetime, "Should respect gt constraint"

0 commit comments

Comments
 (0)