diff --git a/src/psygnal/containers/_evented_dict.py b/src/psygnal/containers/_evented_dict.py index 3e45c704..0887acc3 100644 --- a/src/psygnal/containers/_evented_dict.py +++ b/src/psygnal/containers/_evented_dict.py @@ -4,6 +4,8 @@ from typing import ( TYPE_CHECKING, + Any, + Callable, Iterable, Iterator, Mapping, @@ -13,6 +15,7 @@ Type, TypeVar, Union, + get_args, ) if TYPE_CHECKING: @@ -94,6 +97,24 @@ def copy(self) -> Self: def __copy__(self) -> Self: return self.copy() + # PYDANTIC SUPPORT + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: Callable + ) -> Mapping[str, Any]: + """Return the Pydantic core schema for this object.""" + from pydantic_core import core_schema + + args = get_args(source_type) + return core_schema.no_info_after_validator_function( + function=cls, + schema=core_schema.dict_schema( + keys_schema=handler(args[0]) if args else None, + values_schema=handler(args[1]) if len(args) > 1 else None, + ), + ) + class DictEvents(SignalGroup): """Events available on [EventedDict][psygnal.containers.EventedDict]. diff --git a/src/psygnal/containers/_evented_list.py b/src/psygnal/containers/_evented_list.py index fbffe3cc..3d1525a8 100644 --- a/src/psygnal/containers/_evented_list.py +++ b/src/psygnal/containers/_evented_list.py @@ -22,16 +22,19 @@ cover this in test_evented_list.py) """ -from __future__ import annotations # pragma: no cover +from __future__ import annotations from typing import ( TYPE_CHECKING, Any, + Callable, Iterable, + Mapping, MutableSequence, TypeVar, Union, cast, + get_args, overload, ) @@ -422,3 +425,20 @@ def _reemit_child_event(self, *args: Any) -> None: emitter, args = args[0] self.events.child_event.emit(idx, obj, emitter, args) + + # PYDANTIC SUPPORT + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: Callable + ) -> Mapping[str, Any]: + """Return the Pydantic core schema for this object.""" + from pydantic_core import core_schema + + args = get_args(source_type) + return core_schema.no_info_after_validator_function( + function=cls, + schema=core_schema.list_schema( + items_schema=handler(args[0]) if args else None, + ), + ) diff --git a/src/psygnal/containers/_evented_set.py b/src/psygnal/containers/_evented_set.py index 50b55346..73770a7a 100644 --- a/src/psygnal/containers/_evented_set.py +++ b/src/psygnal/containers/_evented_set.py @@ -2,7 +2,18 @@ import inspect from itertools import chain -from typing import TYPE_CHECKING, Any, Final, Iterable, Iterator, MutableSet, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Final, + Iterable, + Iterator, + Mapping, + MutableSet, + TypeVar, + get_args, +) from psygnal import Signal, SignalGroup @@ -156,6 +167,23 @@ def union(self, *s: Iterable[_T]) -> Self: new.update(*s) return new + # PYDANTIC SUPPORT + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: Callable + ) -> Mapping[str, Any]: + """Return the Pydantic core schema for this object.""" + from pydantic_core import core_schema + + args = get_args(source_type) + return core_schema.no_info_after_validator_function( + function=cls, + schema=core_schema.set_schema( + items_schema=handler(args[0]) if args else None, + ), + ) + class OrderedSet(_BaseMutableSet[_T]): """A set that preserves insertion order, uses dict behind the scenes.""" diff --git a/tests/test_pydantic_support.py b/tests/test_pydantic_support.py new file mode 100644 index 00000000..801f0293 --- /dev/null +++ b/tests/test_pydantic_support.py @@ -0,0 +1,101 @@ +from typing import Any, get_origin + +import pytest + +try: + import pydantic +except ImportError: + pytest.skip("pydantic not installed", allow_module_level=True) + +from psygnal import containers + +V1 = pydantic.__version__.startswith("1") + + +@pytest.mark.skipif(V1, reason="pydantic v1 has poor support for generics") +@pytest.mark.parametrize( + "hint", + [ + containers.EventedList[int], + containers.SelectableEventedList[int], + ], +) +def test_evented_list_as_pydantic_field(hint: Any) -> None: + class Model(pydantic.BaseModel): + my_list: hint + + m = Model(my_list=[1, 2, 3]) # type: ignore + assert m.my_list == [1, 2, 3] + assert isinstance(m.my_list, get_origin(hint)) + + m2 = Model(my_list=containers.EventedList([1, 2, 3])) + assert m2.my_list == [1, 2, 3] + m3 = Model(my_list=[1, "2", 3]) # type: ignore + assert m3.my_list == [1, 2, 3] + assert isinstance(m3.my_list, get_origin(hint)) + + with pytest.raises(pydantic.ValidationError): + Model(my_list=[1, 2, "string"]) # type: ignore + + +@pytest.mark.skipif(V1, reason="pydantic v1 has poor support for generics") +def test_evented_list_no_params_as_pydantic_field() -> None: + class Model(pydantic.BaseModel): + my_list: containers.EventedList + + m = Model(my_list=[1, 2, 3]) # type: ignore + assert m.my_list == [1, 2, 3] + assert isinstance(m.my_list, containers.EventedList) + + m3 = Model(my_list=[1, "string", 3]) # type: ignore + assert m3.my_list == [1, "string", 3] + assert isinstance(m3.my_list, containers.EventedList) + + +@pytest.mark.skipif(V1, reason="pydantic v1 has poor support for generics") +@pytest.mark.parametrize( + "hint", + [ + containers.EventedSet[str], + containers.EventedOrderedSet[str], + containers.Selection[str], + ], +) +def test_evented_set_as_pydantic_field(hint: Any) -> None: + class Model(pydantic.BaseModel): + my_set: hint + + model_config = {"coerce_numbers_to_str": True} + + m = Model(my_set=[1, 2]) # type: ignore + assert m.my_set == {"1", "2"} # type: ignore + assert isinstance(m.my_set, get_origin(hint)) + + m2 = Model(my_set=containers.EventedSet(["a", "b"])) + assert m2.my_set == {"a", "b"} # type: ignore + m3 = Model(my_set=[1, "2", 3]) # type: ignore + assert m3.my_set == {"1", "2", "3"} # type: ignore + assert isinstance(m3.my_set, get_origin(hint)) + + +@pytest.mark.skipif(V1, reason="pydantic v1 has poor support for generics") +def test_evented_dict_as_pydantic_field() -> None: + class Model(pydantic.BaseModel): + my_dict: containers.EventedDict[str, int] + + model_config = {"coerce_numbers_to_str": True} + + m = Model(my_dict={"a": 1}) # type: ignore + assert m.my_dict == {"a": 1} + assert isinstance(m.my_dict, containers.EventedDict) + + m2 = Model(my_dict=containers.EventedDict({"a": 1})) + assert m2.my_dict == {"a": 1} + assert isinstance(m2.my_dict, containers.EventedDict) + + m3 = Model(my_dict={1: "2"}) # type: ignore + assert m3.my_dict == {"1": 2} + assert isinstance(m3.my_dict, containers.EventedDict) + + with pytest.raises(pydantic.ValidationError): + Model(my_dict={"a": "string"}) # type: ignore