Skip to content

Commit 83d0308

Browse files
committed
fix: handle constrained datetimes
1 parent 598344f commit 83d0308

File tree

2 files changed

+22
-19
lines changed

2 files changed

+22
-19
lines changed

tests/constraints/test_date_constraints.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import date, timedelta
2+
from typing import Optional
23

34
import pytest
45
from hypothesis import given
@@ -23,8 +24,8 @@
2324
),
2425
)
2526
def test_handle_constrained_date(
26-
start: str | None,
27-
end: str | None,
27+
start: Optional[str],
28+
end: Optional[str],
2829
start_date: date,
2930
end_date: date,
3031
) -> None:

tests/constraints/test_datetime_constraints.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
"""Test datetime constraints, including Issue #734."""
22

3+
import contextlib
34
from datetime import datetime, timedelta, timezone
4-
from typing import Annotated
5+
from typing import Annotated, Optional
56

67
import pytest
78
from annotated_types import Timezone
89
from hypothesis import given
910
from hypothesis.strategies import datetimes
10-
from typing_extensions import Literal
1111

12-
from pydantic import BaseModel, BeforeValidator, Field
12+
from pydantic import BaseModel, Field, __version__
13+
14+
with contextlib.suppress(ImportError):
15+
from pydantic import BeforeValidator
16+
1317

1418
from polyfactory.factories.pydantic_factory import ModelFactory
1519

@@ -28,16 +32,16 @@
2832
),
2933
)
3034
def test_handle_constrained_datetime(
31-
start: Literal["ge", "gt"],
32-
end: Literal["le", "lt"],
35+
start: Optional[str],
36+
end: Optional[str],
3337
start_datetime: datetime,
3438
end_datetime: datetime,
3539
) -> None:
3640
"""Test that constrained datetimes are generated correctly."""
3741
if start_datetime == end_datetime:
3842
return
3943

40-
kwargs: dict[Literal["ge", "gt", "le", "lt"], datetime] = {}
44+
kwargs: dict[str, datetime] = {}
4145
if start:
4246
kwargs[start] = start_datetime
4347
if end:
@@ -56,21 +60,19 @@ class MyFactory(ModelFactory[MyModel]): ...
5660
assert result.value <= end_datetime if "le" in kwargs else result.value < end_datetime
5761

5862

59-
def validate_datetime(value: datetime) -> datetime:
60-
"""Validator that expects a datetime object with timezone info."""
61-
assert isinstance(value, datetime), f"Expected datetime.datetime, got {type(value)}"
62-
assert value.tzinfo == timezone.utc, f"Expected UTC timezone, got {value.tzinfo}"
63-
return value
64-
65-
66-
ValidatedDatetime = Annotated[datetime, BeforeValidator(validate_datetime), Timezone(tz=timezone.utc)]
67-
68-
63+
@pytest.mark.skipif(__version__.startswith("1"), reason="Pydantic v2 required")
6964
def test_annotated_datetime_with_validator_and_constraint() -> None:
65+
def validate_datetime(value: datetime) -> datetime:
66+
"""Validator that expects a datetime object with timezone info."""
67+
assert isinstance(value, datetime), f"Expected datetime.datetime, got {type(value)}"
68+
assert value.tzinfo == timezone.utc, f"Expected UTC timezone, got {value.tzinfo}"
69+
return value
70+
71+
ValidatedDatetime = Annotated[datetime, BeforeValidator(validate_datetime), Timezone(tz=timezone.utc)]
7072
minimum_datetime = datetime(2030, 1, 1, tzinfo=timezone.utc)
7173

7274
class MyModel(BaseModel):
73-
dt: ValidatedDatetime = Field(gt=minimum_datetime)
75+
dt: ValidatedDatetime = Field(gt=minimum_datetime) # pyright: ignore[reportInvalidTypeForm]
7476

7577
class MyModelFactory(ModelFactory[MyModel]): ...
7678

0 commit comments

Comments
 (0)