Skip to content

Commit

Permalink
Further improved the implementation and removed protocol check caching
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Oct 27, 2024
1 parent 26fd117 commit d344b6e
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 129 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ This library adheres to
(`#465 <https://github.com/agronholm/typeguard/pull/465>`_)
- Fixed basic support for intersection protocols
(`#490 <https://github.com/agronholm/typeguard/pull/490>`_; PR by @antonagestam)
- Fixed protocol checks running against the class of an instance and not the instance
itself (this produced wrong results for non-method member checks)

**4.3.0** (2024-05-27)

Expand Down
134 changes: 68 additions & 66 deletions src/typeguard/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
Union,
)
from unittest.mock import Mock
from weakref import WeakKeyDictionary

import typing_extensions

Expand Down Expand Up @@ -87,10 +86,6 @@
if sys.version_info >= (3, 9):
generic_alias_types += (types.GenericAlias,)

protocol_check_cache: WeakKeyDictionary[
type[Any], dict[type[Any], tuple[Any, ...] | None]
] = WeakKeyDictionary()

# Sentinel
_missing = object()

Expand Down Expand Up @@ -644,6 +639,33 @@ def check_signature_compatible(
) -> None:
subject_sig = inspect.signature(subject_callable)
protocol_sig = inspect.signature(getattr(protocol, attrname))
protocol_type: typing.Literal["instance", "class", "static"] = "instance"
subject_type: typing.Literal["instance", "class", "static"] = "instance"

# Check if the protocol-side method is a class method or static method
if attrname in protocol.__dict__:
descriptor = protocol.__dict__[attrname]
if isinstance(descriptor, staticmethod):
protocol_type = "static"
elif isinstance(descriptor, classmethod):
protocol_type = "class"

# Check if the subject-side method is a class method or static method
if inspect.ismethod(subject_callable) and inspect.isclass(
subject_callable.__self__
):
subject_type = "class"
elif not hasattr(subject_callable, "__self__"):
subject_type = "static"

if protocol_type == "instance" and subject_type != "instance":
raise TypeCheckError(
f"should be an instance method but it's a {subject_type} method"
)
elif protocol_type != "instance" and subject_type == "instance":
raise TypeCheckError(
f"should be a {protocol_type} method but it's an instance method"
)

expected_varargs = any(
param
Expand Down Expand Up @@ -687,12 +709,9 @@ def check_signature_compatible(
in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
]

# Remove the "self" parameter from methods
if inspect.ismethod(subject_callable) or inspect.ismethoddescriptor(
subject_callable
):
# Remove the "self" parameter from the protocol arguments to match
if protocol_type == "instance":
protocol_args.pop(0)
subject_args.pop(0)

for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args):
if protocol_arg is None:
Expand Down Expand Up @@ -763,65 +782,48 @@ def check_protocol(
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
subject: type[Any] = value if isclass(value) else type(value)
origin_annotations = typing.get_type_hints(origin_type)
for attrname in sorted(typing_extensions.get_protocol_members(origin_type)):
if (annotation := origin_annotations.get(attrname)) is not None:
try:
subject_member = getattr(value, attrname)
except AttributeError:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because it has no attribute named {attrname!r}"
) from None

if subject in protocol_check_cache:
result_map = protocol_check_cache[subject]
if origin_type in result_map:
if exc_args := result_map[origin_type]:
raise TypeCheckError(*exc_args)
else:
return
try:
check_type_internal(subject_member, annotation, memo)
except TypeCheckError as exc:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} attribute {exc}"
) from None
elif callable(getattr(origin_type, attrname)):
try:
subject_member = getattr(value, attrname)
except AttributeError:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because it has no method named {attrname!r}"
) from None

origin_annotations = typing.get_type_hints(origin_type)
result_map = protocol_check_cache.setdefault(subject, {})
try:
for attrname in sorted(typing_extensions.get_protocol_members(origin_type)):
if (annotation := origin_annotations.get(attrname)) is not None:
try:
subject_member = getattr(subject, attrname)
except AttributeError:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because it has no attribute named {attrname!r}"
) from None
if not callable(subject_member):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} attribute is not a callable"
)

try:
check_type_internal(subject_member, annotation, memo)
except TypeCheckError as exc:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} attribute {exc}"
) from None
elif callable(getattr(origin_type, attrname)):
try:
subject_member = getattr(subject, attrname)
except AttributeError:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because it has no method named {attrname!r}"
) from None

if not callable(subject_member):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} attribute is not a callable"
)

# TODO: implement assignability checks for parameter and return value
# annotations
try:
check_signature_compatible(subject_member, origin_type, attrname)
except TypeCheckError as exc:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} method {exc}"
) from None
except TypeCheckError as exc:
result_map[origin_type] = exc.args
raise
else:
result_map[origin_type] = None
# TODO: implement assignability checks for parameter and return value
# annotations
try:
check_signature_compatible(subject_member, origin_type, attrname)
except TypeCheckError as exc:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} method {exc}"
) from None


def check_byteslike(
Expand Down
146 changes: 83 additions & 63 deletions tests/test_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,20 +1137,23 @@ def my_static_method(cls, x: int, y: str) -> None:
def my_class_method(x: int, y: str) -> None:
pass

for _ in range(2): # Makes sure that the cache is also exercised
check_type(Foo(), MyProtocol)
check_type(Foo(), MyProtocol)

def test_missing_member(self) -> None:
@pytest.mark.parametrize("has_member", [True, False])
def test_member_checks(self, has_member: bool) -> None:
class MyProtocol(Protocol):
member: int

class Foo:
pass
def __init__(self, member: int):
if member:
self.member = member

obj = Foo()
for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, obj, MyProtocol).match(
f"^{qualified_name(obj)} is not compatible with the "
if has_member:
check_type(Foo(1), MyProtocol)
else:
pytest.raises(TypeCheckError, check_type, Foo(0), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because it has no attribute named "
f"'member'"
)
Expand All @@ -1163,12 +1166,11 @@ def meth(self) -> None:
class Foo:
pass

for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because it has no method named "
f"'meth'"
)
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because it has no method named "
f"'meth'"
)

def test_too_many_posargs(self) -> None:
class MyProtocol(Protocol):
Expand All @@ -1179,13 +1181,11 @@ class Foo:
def meth(self, x: str) -> None:
pass

obj = Foo()
for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(obj)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method has too "
f"many mandatory positional arguments"
)
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method has too "
f"many mandatory positional arguments"
)

def test_wrong_posarg_name(self) -> None:
class MyProtocol(Protocol):
Expand All @@ -1196,13 +1196,11 @@ class Foo:
def meth(self, y: str) -> None:
pass

obj = Foo()
for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
rf"^{qualified_name(obj)} is not compatible with the "
rf"{MyProtocol.__qualname__} protocol because its 'meth' method has a "
rf"positional argument \(y\) that should be named 'x' at this position"
)
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
rf"^{qualified_name(Foo)} is not compatible with the "
rf"{MyProtocol.__qualname__} protocol because its 'meth' method has a "
rf"positional argument \(y\) that should be named 'x' at this position"
)

def test_too_few_posargs(self) -> None:
class MyProtocol(Protocol):
Expand All @@ -1213,13 +1211,11 @@ class Foo:
def meth(self) -> None:
pass

obj = Foo()
for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(obj)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method has too "
f"few positional arguments"
)
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method has too "
f"few positional arguments"
)

def test_no_varargs(self) -> None:
class MyProtocol(Protocol):
Expand All @@ -1230,13 +1226,11 @@ class Foo:
def meth(self) -> None:
pass

obj = Foo()
for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(obj)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
f"accept variable positional arguments but doesn't"
)
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
f"accept variable positional arguments but doesn't"
)

def test_no_kwargs(self) -> None:
class MyProtocol(Protocol):
Expand All @@ -1247,13 +1241,11 @@ class Foo:
def meth(self) -> None:
pass

obj = Foo()
for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(obj)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
f"accept variable keyword arguments but doesn't"
)
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
f"accept variable keyword arguments but doesn't"
)

def test_missing_kwarg(self) -> None:
class MyProtocol(Protocol):
Expand All @@ -1264,13 +1256,11 @@ class Foo:
def meth(self) -> None:
pass

obj = Foo()
for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(obj)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method is "
f"missing keyword-only arguments: x"
)
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method is "
f"missing keyword-only arguments: x"
)

def test_extra_kwarg(self) -> None:
class MyProtocol(Protocol):
Expand All @@ -1281,13 +1271,43 @@ class Foo:
def meth(self, *, x: str) -> None:
pass

obj = Foo()
for _ in range(2): # Makes sure that the cache is also exercised
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(obj)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method has "
f"mandatory keyword-only arguments not present in the protocol: x"
)
pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method has "
f"mandatory keyword-only arguments not present in the protocol: x"
)

def test_instance_staticmethod_mismatch(self) -> None:
class MyProtocol(Protocol):
@staticmethod
def meth() -> None:
pass

class Foo:
def meth(self) -> None:
pass

pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
f"be a static method but it's an instance method"
)

def test_instance_classmethod_mismatch(self) -> None:
class MyProtocol(Protocol):
@classmethod
def meth(cls) -> None:
pass

class Foo:
def meth(self) -> None:
pass

pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match(
f"^{qualified_name(Foo)} is not compatible with the "
f"{MyProtocol.__qualname__} protocol because its 'meth' method should "
f"be a class method but it's an instance method"
)


class TestRecursiveType:
Expand Down

0 comments on commit d344b6e

Please sign in to comment.