Skip to content

Commit

Permalink
Added proper Protocol method signature checking (#496)
Browse files Browse the repository at this point in the history
It's not good enough to pretend we can use `check_callable()` to check method signature compatibility.

Fixes #465.
  • Loading branch information
agronholm authored Oct 27, 2024
1 parent afad2c7 commit b72794d
Show file tree
Hide file tree
Showing 5 changed files with 375 additions and 178 deletions.
9 changes: 2 additions & 7 deletions docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,8 @@ As of version 4.3.0, Typeguard can check instances and classes against Protocols
regardless of whether they were annotated with
:func:`@runtime_checkable <typing.runtime_checkable>`.

There are several limitations on the checks performed, however:

* For non-callable members, only presence is checked for; no type compatibility checks
are performed
* For methods, only the number of positional arguments are checked against, so any added
keyword-only arguments without defaults don't currently trip the checker
* Likewise, argument types are not checked for compatibility
The only current limitation is that argument annotations are not checked for
compatibility, however this should be covered by static type checkers pretty well.

Special considerations for ``if TYPE_CHECKING:``
------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ This library adheres to

**UNRELEASED**

- Added proper checking for method signatures in protocol checks
(`#465 <https://github.com/agronholm/typeguard/pull/465>`_)
- Fixed basic support for intersection protocols
(`#490 <https://github.com/agronholm/typeguard/pull/490>`_; PR by @antonagestam)
- Fixed protocol checks running against the class of an instance and not the instance
itself (this produced wrong results for non-method member checks)

**4.3.0** (2024-05-27)

Expand Down
250 changes: 173 additions & 77 deletions src/typeguard/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from enum import Enum
from inspect import Parameter, isclass, isfunction
from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase
from itertools import zip_longest
from textwrap import indent
from typing import (
IO,
Expand All @@ -32,7 +33,6 @@
Union,
)
from unittest.mock import Mock
from weakref import WeakKeyDictionary

import typing_extensions

Expand Down Expand Up @@ -86,10 +86,6 @@
if sys.version_info >= (3, 9):
generic_alias_types += (types.GenericAlias,)

protocol_check_cache: WeakKeyDictionary[
type[Any], dict[type[Any], TypeCheckError | None]
] = WeakKeyDictionary()

# Sentinel
_missing = object()

Expand Down Expand Up @@ -638,96 +634,196 @@ def check_io(
raise TypeCheckError("is not an I/O object")


def check_protocol(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
def check_signature_compatible(
subject_callable: Callable[..., Any], protocol: type, attrname: str
) -> None:
subject: type[Any] = value if isclass(value) else type(value)
subject_sig = inspect.signature(subject_callable)
protocol_sig = inspect.signature(getattr(protocol, attrname))
protocol_type: typing.Literal["instance", "class", "static"] = "instance"
subject_type: typing.Literal["instance", "class", "static"] = "instance"

# Check if the protocol-side method is a class method or static method
if attrname in protocol.__dict__:
descriptor = protocol.__dict__[attrname]
if isinstance(descriptor, staticmethod):
protocol_type = "static"
elif isinstance(descriptor, classmethod):
protocol_type = "class"

# Check if the subject-side method is a class method or static method
if inspect.ismethod(subject_callable) and inspect.isclass(
subject_callable.__self__
):
subject_type = "class"
elif not hasattr(subject_callable, "__self__"):
subject_type = "static"

if subject in protocol_check_cache:
result_map = protocol_check_cache[subject]
if origin_type in result_map:
if exc := result_map[origin_type]:
raise exc
else:
return
if protocol_type == "instance" and subject_type != "instance":
raise TypeCheckError(
f"should be an instance method but it's a {subject_type} method"
)
elif protocol_type != "instance" and subject_type == "instance":
raise TypeCheckError(
f"should be a {protocol_type} method but it's an instance method"
)

expected_methods: dict[str, tuple[Any, Any]] = {}
expected_noncallable_members: dict[str, Any] = {}
origin_annotations = typing.get_type_hints(origin_type)
expected_varargs = any(
param
for param in protocol_sig.parameters.values()
if param.kind is Parameter.VAR_POSITIONAL
)
has_varargs = any(
param
for param in subject_sig.parameters.values()
if param.kind is Parameter.VAR_POSITIONAL
)
if expected_varargs and not has_varargs:
raise TypeCheckError("should accept variable positional arguments but doesn't")

protocol_has_varkwargs = any(
param
for param in protocol_sig.parameters.values()
if param.kind is Parameter.VAR_KEYWORD
)
subject_has_varkwargs = any(
param
for param in subject_sig.parameters.values()
if param.kind is Parameter.VAR_KEYWORD
)
if protocol_has_varkwargs and not subject_has_varkwargs:
raise TypeCheckError("should accept variable keyword arguments but doesn't")

# Check that the callable has at least the expect amount of positional-only
# arguments (and no extra positional-only arguments without default values)
if not has_varargs:
protocol_args = [
param
for param in protocol_sig.parameters.values()
if param.kind
in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
]
subject_args = [
param
for param in subject_sig.parameters.values()
if param.kind
in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
]

# Remove the "self" parameter from the protocol arguments to match
if protocol_type == "instance":
protocol_args.pop(0)

for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args):
if protocol_arg is None:
if subject_arg.default is Parameter.empty:
raise TypeCheckError("has too many mandatory positional arguments")

break

if subject_arg is None:
raise TypeCheckError("has too few positional arguments")

if (
protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD
and subject_arg.kind is Parameter.POSITIONAL_ONLY
):
raise TypeCheckError(
f"has an argument ({subject_arg.name}) that should not be "
f"positional-only"
)

if (
protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD
and protocol_arg.name != subject_arg.name
):
raise TypeCheckError(
f"has a positional argument ({subject_arg.name}) that should be "
f"named {protocol_arg.name!r} at this position"
)

for attrname in typing_extensions.get_protocol_members(origin_type):
member = getattr(origin_type, attrname, None)

if callable(member):
signature = inspect.signature(member)
argtypes = [
(p.annotation if p.annotation is not Parameter.empty else Any)
for p in signature.parameters.values()
if p.kind is not Parameter.KEYWORD_ONLY
] or Ellipsis
return_annotation = (
signature.return_annotation
if signature.return_annotation is not Parameter.empty
else Any
protocol_kwonlyargs = {
param.name: param
for param in protocol_sig.parameters.values()
if param.kind is Parameter.KEYWORD_ONLY
}
subject_kwonlyargs = {
param.name: param
for param in subject_sig.parameters.values()
if param.kind is Parameter.KEYWORD_ONLY
}
if not subject_has_varkwargs:
# Check that the signature has at least the required keyword-only arguments, and
# no extra mandatory keyword-only arguments
if missing_kwonlyargs := [
param.name
for param in protocol_kwonlyargs.values()
if param.name not in subject_kwonlyargs
]:
raise TypeCheckError(
"is missing keyword-only arguments: " + ", ".join(missing_kwonlyargs)
)
expected_methods[attrname] = argtypes, return_annotation
else:
try:
expected_noncallable_members[attrname] = origin_annotations[attrname]
except KeyError:
expected_noncallable_members[attrname] = member

subject_annotations = typing.get_type_hints(subject)
if not protocol_has_varkwargs:
if extra_kwonlyargs := [
param.name
for param in subject_kwonlyargs.values()
if param.default is Parameter.empty
and param.name not in protocol_kwonlyargs
]:
raise TypeCheckError(
"has mandatory keyword-only arguments not present in the protocol: "
+ ", ".join(extra_kwonlyargs)
)

# Check that all required methods are present and their signatures are compatible
result_map = protocol_check_cache.setdefault(subject, {})
try:
for attrname, callable_args in expected_methods.items():

def check_protocol(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
origin_annotations = typing.get_type_hints(origin_type)
for attrname in sorted(typing_extensions.get_protocol_members(origin_type)):
if (annotation := origin_annotations.get(attrname)) is not None:
try:
method = getattr(subject, attrname)
subject_member = getattr(value, attrname)
except AttributeError:
if attrname in subject_annotations:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because its {attrname!r} attribute is not a method"
) from None
else:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because it has no method named {attrname!r}"
) from None

if not callable(method):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because its {attrname!r} attribute is not a callable"
)
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because it has no attribute named {attrname!r}"
) from None

# TODO: raise exception on added keyword-only arguments without defaults
try:
check_callable(method, Callable, callable_args, memo)
check_type_internal(subject_member, annotation, memo)
except TypeCheckError as exc:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because its {attrname!r} method {exc}"
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} attribute {exc}"
) from None
elif callable(getattr(origin_type, attrname)):
try:
subject_member = getattr(value, attrname)
except AttributeError:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because it has no method named {attrname!r}"
) from None

# Check that all required non-callable members are present
for attrname in expected_noncallable_members:
# TODO: implement assignability checks for non-callable members
if attrname not in subject_annotations and not hasattr(subject, attrname):
if not callable(subject_member):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because it has no attribute named {attrname!r}"
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} attribute is not a callable"
)
except TypeCheckError as exc:
result_map[origin_type] = exc
raise
else:
result_map[origin_type] = None

# TODO: implement assignability checks for parameter and return value
# annotations
try:
check_signature_compatible(subject_member, origin_type, attrname)
except TypeCheckError as exc:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
f"protocol because its {attrname!r} method {exc}"
) from None


def check_byteslike(
Expand Down
15 changes: 0 additions & 15 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
List,
NamedTuple,
NewType,
Protocol,
TypeVar,
Union,
runtime_checkable,
)

T_Foo = TypeVar("T_Foo")
Expand Down Expand Up @@ -44,16 +42,3 @@ class Parent:
class Child(Parent):
def method(self, a: int) -> None:
pass


class StaticProtocol(Protocol):
member: int

def meth(self, x: str) -> None: ...


@runtime_checkable
class RuntimeProtocol(Protocol):
member: int

def meth(self, x: str) -> None: ...
Loading

0 comments on commit b72794d

Please sign in to comment.