Skip to content

Commit

Permalink
Use get_protocol_members in protocol checking
Browse files Browse the repository at this point in the history
This changes `check_protocol()` to make use of `get_protocol_members`
from typing-extensions. This allows removing an existing hard-coded
exclusion list for attributes existing on Protocol, but also handles the
cases `__orig_bases__` and `__weakref__` that was breaking when checking
intersecting protocols (a subclass of two or more protocols).

This has the effect of turning some false positives into true negatives,
but it also leaves some false negatives. To make that clear, xfail test
cases are added for the resulting false negatives.
  • Loading branch information
antonagestam committed Sep 21, 2024
1 parent ac7ac34 commit 740574f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 14 deletions.
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Version history
This library adheres to
`Semantic Versioning 2.0 <https://semver.org/#semantic-versioning-200>`_.

**UNRELEASED**

- Fixed basic support for intersection protocols
(`#490 <https://github.com/agronholm/typeguard/pull/490>`_; PR by @antonagestam)

**4.3.0** (2024-05-27)

- Added support for checking against static protocols
Expand Down
22 changes: 8 additions & 14 deletions src/typeguard/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,19 +654,13 @@ def check_protocol(
else:
return

# Collect a set of methods and non-method attributes present in the protocol
ignored_attrs = set(dir(typing.Protocol)) | {
"__annotations__",
"__non_callable_proto_members__",
}
expected_methods: dict[str, tuple[Any, Any]] = {}
expected_noncallable_members: dict[str, Any] = {}
for attrname in dir(origin_type):
# Skip attributes present in typing.Protocol
if attrname in ignored_attrs:
continue
origin_annotations = typing.get_type_hints(origin_type)

for attrname in typing_extensions.get_protocol_members(origin_type):
member = getattr(origin_type, attrname, None)

member = getattr(origin_type, attrname)
if callable(member):
signature = inspect.signature(member)
argtypes = [
Expand All @@ -681,10 +675,10 @@ def check_protocol(
)
expected_methods[attrname] = argtypes, return_annotation
else:
expected_noncallable_members[attrname] = member

for attrname, annotation in typing.get_type_hints(origin_type).items():
expected_noncallable_members[attrname] = annotation
try:
expected_noncallable_members[attrname] = origin_annotations[attrname]
except KeyError:
expected_noncallable_members[attrname] = member

subject_annotations = typing.get_type_hints(subject)

Expand Down
56 changes: 56 additions & 0 deletions tests/test_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
Dict,
ForwardRef,
FrozenSet,
Iterable,
Iterator,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Protocol,
Sequence,
Set,
Sized,
TextIO,
Tuple,
Type,
Expand Down Expand Up @@ -995,6 +998,59 @@ def test_text_real_file(self, tmp_path: Path):
check_type(f, TextIO)


class TestIntersectingProtocol:
SIT = TypeVar("SIT", bound=object, covariant=True)

class SizedIterable(
Sized,
Iterable[SIT],
Protocol[SIT],
): ...

@pytest.mark.parametrize(
("subject", "predicate_type"),
(
((), SizedIterable),
(range(2), SizedIterable),
((), SizedIterable[int]),
((1, 2, 3), SizedIterable[int]),
(("1", "2", "3"), SizedIterable[str]),
),
)
def test_valid_member_passes(self, subject: object, predicate_type: type) -> None:
for _ in range(2): # Makes sure that the cache is also exercised
check_type(subject, predicate_type)

xfail_nested_protocol_checks = pytest.mark.xfail(
reason="false negative due to missing support for nested protocol checks",
)

@pytest.mark.parametrize(
("subject", "predicate_type"),
(
((1 for _ in ()), SizedIterable),
pytest.param(
range(2),
SizedIterable[str],
marks=xfail_nested_protocol_checks,
),
pytest.param(
(1, 2, 3),
SizedIterable[str],
marks=xfail_nested_protocol_checks,
),
pytest.param(
("1", "2", "3"),
SizedIterable[int],
marks=xfail_nested_protocol_checks,
),
),
)
def test_raises_for_non_member(self, subject: object, predicate_type: type) -> None:
with pytest.raises(TypeCheckError):
check_type(subject, predicate_type)


@pytest.mark.parametrize(
"instantiate, annotation",
[
Expand Down

0 comments on commit 740574f

Please sign in to comment.