Skip to content

Commit

Permalink
feat: support for pydantic v2 as fields (#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 authored Feb 24, 2024
1 parent 8f44f34 commit 0a6f41a
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 2 deletions.
21 changes: 21 additions & 0 deletions src/psygnal/containers/_evented_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Mapping,
Expand All @@ -13,6 +15,7 @@
Type,
TypeVar,
Union,
get_args,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -94,6 +97,24 @@ def copy(self) -> Self:
def __copy__(self) -> Self:
return self.copy()

# PYDANTIC SUPPORT

@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: Callable
) -> Mapping[str, Any]:
"""Return the Pydantic core schema for this object."""
from pydantic_core import core_schema

args = get_args(source_type)
return core_schema.no_info_after_validator_function(
function=cls,
schema=core_schema.dict_schema(
keys_schema=handler(args[0]) if args else None,
values_schema=handler(args[1]) if len(args) > 1 else None,
),
)


class DictEvents(SignalGroup):
"""Events available on [EventedDict][psygnal.containers.EventedDict].
Expand Down
22 changes: 21 additions & 1 deletion src/psygnal/containers/_evented_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@
cover this in test_evented_list.py)
"""

from __future__ import annotations # pragma: no cover
from __future__ import annotations

from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Mapping,
MutableSequence,
TypeVar,
Union,
cast,
get_args,
overload,
)

Expand Down Expand Up @@ -422,3 +425,20 @@ def _reemit_child_event(self, *args: Any) -> None:
emitter, args = args[0]

self.events.child_event.emit(idx, obj, emitter, args)

# PYDANTIC SUPPORT

@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: Callable
) -> Mapping[str, Any]:
"""Return the Pydantic core schema for this object."""
from pydantic_core import core_schema

args = get_args(source_type)
return core_schema.no_info_after_validator_function(
function=cls,
schema=core_schema.list_schema(
items_schema=handler(args[0]) if args else None,
),
)
30 changes: 29 additions & 1 deletion src/psygnal/containers/_evented_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@

import inspect
from itertools import chain
from typing import TYPE_CHECKING, Any, Final, Iterable, Iterator, MutableSet, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Final,
Iterable,
Iterator,
Mapping,
MutableSet,
TypeVar,
get_args,
)

from psygnal import Signal, SignalGroup

Expand Down Expand Up @@ -156,6 +167,23 @@ def union(self, *s: Iterable[_T]) -> Self:
new.update(*s)
return new

# PYDANTIC SUPPORT

@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: Callable
) -> Mapping[str, Any]:
"""Return the Pydantic core schema for this object."""
from pydantic_core import core_schema

args = get_args(source_type)
return core_schema.no_info_after_validator_function(
function=cls,
schema=core_schema.set_schema(
items_schema=handler(args[0]) if args else None,
),
)


class OrderedSet(_BaseMutableSet[_T]):
"""A set that preserves insertion order, uses dict behind the scenes."""
Expand Down
101 changes: 101 additions & 0 deletions tests/test_pydantic_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Any, get_origin

import pytest

try:
import pydantic
except ImportError:
pytest.skip("pydantic not installed", allow_module_level=True)

from psygnal import containers

V1 = pydantic.__version__.startswith("1")


@pytest.mark.skipif(V1, reason="pydantic v1 has poor support for generics")
@pytest.mark.parametrize(
"hint",
[
containers.EventedList[int],
containers.SelectableEventedList[int],
],
)
def test_evented_list_as_pydantic_field(hint: Any) -> None:
class Model(pydantic.BaseModel):
my_list: hint

m = Model(my_list=[1, 2, 3]) # type: ignore
assert m.my_list == [1, 2, 3]
assert isinstance(m.my_list, get_origin(hint))

m2 = Model(my_list=containers.EventedList([1, 2, 3]))
assert m2.my_list == [1, 2, 3]
m3 = Model(my_list=[1, "2", 3]) # type: ignore
assert m3.my_list == [1, 2, 3]
assert isinstance(m3.my_list, get_origin(hint))

with pytest.raises(pydantic.ValidationError):
Model(my_list=[1, 2, "string"]) # type: ignore


@pytest.mark.skipif(V1, reason="pydantic v1 has poor support for generics")
def test_evented_list_no_params_as_pydantic_field() -> None:
class Model(pydantic.BaseModel):
my_list: containers.EventedList

m = Model(my_list=[1, 2, 3]) # type: ignore
assert m.my_list == [1, 2, 3]
assert isinstance(m.my_list, containers.EventedList)

m3 = Model(my_list=[1, "string", 3]) # type: ignore
assert m3.my_list == [1, "string", 3]
assert isinstance(m3.my_list, containers.EventedList)


@pytest.mark.skipif(V1, reason="pydantic v1 has poor support for generics")
@pytest.mark.parametrize(
"hint",
[
containers.EventedSet[str],
containers.EventedOrderedSet[str],
containers.Selection[str],
],
)
def test_evented_set_as_pydantic_field(hint: Any) -> None:
class Model(pydantic.BaseModel):
my_set: hint

model_config = {"coerce_numbers_to_str": True}

m = Model(my_set=[1, 2]) # type: ignore
assert m.my_set == {"1", "2"} # type: ignore
assert isinstance(m.my_set, get_origin(hint))

m2 = Model(my_set=containers.EventedSet(["a", "b"]))
assert m2.my_set == {"a", "b"} # type: ignore
m3 = Model(my_set=[1, "2", 3]) # type: ignore
assert m3.my_set == {"1", "2", "3"} # type: ignore
assert isinstance(m3.my_set, get_origin(hint))


@pytest.mark.skipif(V1, reason="pydantic v1 has poor support for generics")
def test_evented_dict_as_pydantic_field() -> None:
class Model(pydantic.BaseModel):
my_dict: containers.EventedDict[str, int]

model_config = {"coerce_numbers_to_str": True}

m = Model(my_dict={"a": 1}) # type: ignore
assert m.my_dict == {"a": 1}
assert isinstance(m.my_dict, containers.EventedDict)

m2 = Model(my_dict=containers.EventedDict({"a": 1}))
assert m2.my_dict == {"a": 1}
assert isinstance(m2.my_dict, containers.EventedDict)

m3 = Model(my_dict={1: "2"}) # type: ignore
assert m3.my_dict == {"1": 2}
assert isinstance(m3.my_dict, containers.EventedDict)

with pytest.raises(pydantic.ValidationError):
Model(my_dict={"a": "string"}) # type: ignore

0 comments on commit 0a6f41a

Please sign in to comment.