Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add dataclass_transform to maintain IDE typing support for EventedModel.__init__ #154

Merged
merged 9 commits into from
Dec 20, 2022
38 changes: 25 additions & 13 deletions src/psygnal/_evented_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import sys
import warnings
from contextlib import contextmanager
Expand All @@ -19,20 +17,29 @@

import pydantic.main
from pydantic import BaseModel, PrivateAttr, utils
from pydantic.fields import Field, FieldInfo

from ._evented_decorator import _check_field_equality, _pick_equality_operator
from ._group import SignalGroup
from ._signal import Signal, SignalInstance

if TYPE_CHECKING:
import inspect
from inspect import Signature

from pydantic import BaseConfig
from pydantic.fields import ModelField
from typing_extensions import dataclass_transform

ConfigType = Type[BaseConfig]
EqOperator = Callable[[Any, Any], bool]

else:
try:
from typing_extensions import dataclass_transform
except ImportError: # pragma: no cover

def dataclass_transform(*args, **kwargs):
return lambda a: a


_NULL = object()
ALLOW_PROPERTY_SETTERS = "allow_property_setters"
PROPERTY_DEPENDENCIES = "property_dependencies"
Expand Down Expand Up @@ -73,7 +80,7 @@ def no_class_attributes() -> Iterator[None]: # pragma: no cover

# monkey patch the pydantic ClassAttribute object
# the second argument to ClassAttribute is the inspect.Signature object
def _return2(x: str, y: inspect.Signature) -> inspect.Signature:
def _return2(x: str, y: "Signature") -> "Signature":
return y

pydantic.main.ClassAttribute = _return2 # type: ignore
Expand All @@ -84,6 +91,7 @@ def _return2(x: str, y: inspect.Signature) -> inspect.Signature:
pydantic.main.ClassAttribute = utils.ClassAttribute # type: ignore


@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
class EventedMetaclass(pydantic.main.ModelMetaclass):
"""pydantic ModelMetaclass that preps "equality checking" operations.

Expand All @@ -101,14 +109,14 @@ class EventedMetaclass(pydantic.main.ModelMetaclass):
@no_type_check
def __new__( # noqa: C901
mcs: type, name: str, bases: tuple, namespace: dict, **kwargs: Any
) -> EventedMetaclass:
) -> "EventedMetaclass":
"""Create new EventedModel class."""
with no_class_attributes():
cls = super().__new__(mcs, name, bases, namespace, **kwargs)

cls.__eq_operators__ = {}
signals = {}
fields: Dict[str, ModelField] = cls.__fields__
fields: Dict[str, "ModelField"] = cls.__fields__
for n, f in fields.items():
cls.__eq_operators__[n] = _pick_equality_operator(f.type_)
if f.field_info.allow_mutation:
Expand Down Expand Up @@ -153,7 +161,7 @@ def __new__( # noqa: C901
return cls


def _get_field_dependents(cls: EventedModel) -> Dict[str, Set[str]]: # noqa: C901
def _get_field_dependents(cls: "EventedModel") -> Dict[str, Set[str]]: # noqa: C901
"""Return mapping of field name -> dependent set of property names.

Dependencies may be declared in the Model Config to emit an event
Expand Down Expand Up @@ -288,14 +296,14 @@ class Config:
"""

# add private attributes for event emission
_events: SignalGroup = PrivateAttr()
_events: ClassVar[SignalGroup] = PrivateAttr()

# mapping of name -> property obj for methods that are property setters
__property_setters__: ClassVar[Dict[str, property]]
# mapping of field name -> dependent set of property names
# when field is changed, an event for dependent properties will be emitted.
__field_dependents__: ClassVar[Dict[str, Set[str]]]
__eq_operators__: ClassVar[Dict[str, EqOperator]]
__eq_operators__: ClassVar[Dict[str, "EqOperator"]]
__slots__ = {"__weakref__"}
__signal_group__: ClassVar[Type[SignalGroup]]
# pydantic BaseModel configuration. see:
Expand All @@ -307,7 +315,11 @@ class Config:

def __init__(_model_self_, **data: Any) -> None:
super().__init__(**data)
_model_self_._events = _model_self_.__signal_group__(_model_self_)
Group = _model_self_.__signal_group__
# the type error is "cannot assign to a class variable" ...
# but if we don't use `ClassVar`, then the `dataclass_transform` decorator
# will add _events: SignalGroup to the __init__ signature, for *all* user models
_model_self_._events = Group(_model_self_) # type: ignore [misc]

def _super_setattr_(self, name: str, value: Any) -> None:
# pydantic will raise a ValueError if extra fields are not allowed
Expand Down Expand Up @@ -365,7 +377,7 @@ def reset(self) -> None:
):
setattr(self, name, value)

def update(self, values: Union[EventedModel, dict], recurse: bool = True) -> None:
def update(self, values: Union["EventedModel", dict], recurse: bool = True) -> None:
"""Update a model in place.

Parameters
Expand Down