Skip to content

Commit fc8037d

Browse files
gh-104873: Add typing.get_protocol_members and typing.is_protocol (#104878)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
1 parent ba516e7 commit fc8037d

File tree

5 files changed

+152
-2
lines changed

5 files changed

+152
-2
lines changed

Doc/library/typing.rst

+32
Original file line numberDiff line numberDiff line change
@@ -3388,6 +3388,38 @@ Introspection helpers
33883388

33893389
.. versionadded:: 3.8
33903390

3391+
.. function:: get_protocol_members(tp)
3392+
3393+
Return the set of members defined in a :class:`Protocol`.
3394+
3395+
::
3396+
3397+
>>> from typing import Protocol, get_protocol_members
3398+
>>> class P(Protocol):
3399+
... def a(self) -> str: ...
3400+
... b: int
3401+
>>> get_protocol_members(P)
3402+
frozenset({'a', 'b'})
3403+
3404+
Raise :exc:`TypeError` for arguments that are not Protocols.
3405+
3406+
.. versionadded:: 3.13
3407+
3408+
.. function:: is_protocol(tp)
3409+
3410+
Determine if a type is a :class:`Protocol`.
3411+
3412+
For example::
3413+
3414+
class P(Protocol):
3415+
def a(self) -> str: ...
3416+
b: int
3417+
3418+
is_protocol(P) # => True
3419+
is_protocol(int) # => False
3420+
3421+
.. versionadded:: 3.13
3422+
33913423
.. function:: is_typeddict(tp)
33923424

33933425
Check if a type is a :class:`TypedDict`.

Doc/whatsnew/3.13.rst

+8
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@ traceback
120120
to format the nested exceptions of a :exc:`BaseExceptionGroup` instance, recursively.
121121
(Contributed by Irit Katriel in :gh:`105292`.)
122122

123+
typing
124+
------
125+
126+
* Add :func:`typing.get_protocol_members` to return the set of members
127+
defining a :class:`typing.Protocol`. Add :func:`typing.is_protocol` to
128+
check whether a class is a :class:`typing.Protocol`. (Contributed by Jelle Zijlstra in
129+
:gh:`104873`.)
130+
123131
Optimizations
124132
=============
125133

Lib/test/test_typing.py

+67-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
from typing import Generic, ClassVar, Final, final, Protocol
2525
from typing import assert_type, cast, runtime_checkable
2626
from typing import get_type_hints
27-
from typing import get_origin, get_args
27+
from typing import get_origin, get_args, get_protocol_members
2828
from typing import override
29-
from typing import is_typeddict
29+
from typing import is_typeddict, is_protocol
3030
from typing import reveal_type
3131
from typing import dataclass_transform
3232
from typing import no_type_check, no_type_check_decorator
@@ -3363,6 +3363,18 @@ def meth(self): pass
33633363
self.assertNotIn("__callable_proto_members_only__", vars(NonP))
33643364
self.assertNotIn("__callable_proto_members_only__", vars(NonPR))
33653365

3366+
self.assertEqual(get_protocol_members(P), {"x"})
3367+
self.assertEqual(get_protocol_members(PR), {"meth"})
3368+
3369+
# the returned object should be immutable,
3370+
# and should be a different object to the original attribute
3371+
# to prevent users from (accidentally or deliberately)
3372+
# mutating the attribute on the original class
3373+
self.assertIsInstance(get_protocol_members(P), frozenset)
3374+
self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)
3375+
self.assertIsInstance(get_protocol_members(PR), frozenset)
3376+
self.assertIsNot(get_protocol_members(PR), P.__protocol_attrs__)
3377+
33663378
acceptable_extra_attrs = {
33673379
'_is_protocol', '_is_runtime_protocol', '__parameters__',
33683380
'__init__', '__annotations__', '__subclasshook__',
@@ -3778,6 +3790,59 @@ def __init__(self):
37783790

37793791
Foo() # Previously triggered RecursionError
37803792

3793+
def test_get_protocol_members(self):
3794+
with self.assertRaisesRegex(TypeError, "not a Protocol"):
3795+
get_protocol_members(object)
3796+
with self.assertRaisesRegex(TypeError, "not a Protocol"):
3797+
get_protocol_members(object())
3798+
with self.assertRaisesRegex(TypeError, "not a Protocol"):
3799+
get_protocol_members(Protocol)
3800+
with self.assertRaisesRegex(TypeError, "not a Protocol"):
3801+
get_protocol_members(Generic)
3802+
3803+
class P(Protocol):
3804+
a: int
3805+
def b(self) -> str: ...
3806+
@property
3807+
def c(self) -> int: ...
3808+
3809+
self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'})
3810+
self.assertIsInstance(get_protocol_members(P), frozenset)
3811+
self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)
3812+
3813+
class Concrete:
3814+
a: int
3815+
def b(self) -> str: return "capybara"
3816+
@property
3817+
def c(self) -> int: return 5
3818+
3819+
with self.assertRaisesRegex(TypeError, "not a Protocol"):
3820+
get_protocol_members(Concrete)
3821+
with self.assertRaisesRegex(TypeError, "not a Protocol"):
3822+
get_protocol_members(Concrete())
3823+
3824+
class ConcreteInherit(P):
3825+
a: int = 42
3826+
def b(self) -> str: return "capybara"
3827+
@property
3828+
def c(self) -> int: return 5
3829+
3830+
with self.assertRaisesRegex(TypeError, "not a Protocol"):
3831+
get_protocol_members(ConcreteInherit)
3832+
with self.assertRaisesRegex(TypeError, "not a Protocol"):
3833+
get_protocol_members(ConcreteInherit())
3834+
3835+
def test_is_protocol(self):
3836+
self.assertTrue(is_protocol(Proto))
3837+
self.assertTrue(is_protocol(Point))
3838+
self.assertFalse(is_protocol(Concrete))
3839+
self.assertFalse(is_protocol(Concrete()))
3840+
self.assertFalse(is_protocol(Generic))
3841+
self.assertFalse(is_protocol(object))
3842+
3843+
# Protocol is not itself a protocol
3844+
self.assertFalse(is_protocol(Protocol))
3845+
37813846
def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta(self):
37823847
# Ensure the cache is empty, or this test won't work correctly
37833848
collections.abc.Sized._abc_registry_clear()

Lib/typing.py

+42
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@
131131
'get_args',
132132
'get_origin',
133133
'get_overloads',
134+
'get_protocol_members',
134135
'get_type_hints',
136+
'is_protocol',
135137
'is_typeddict',
136138
'LiteralString',
137139
'Never',
@@ -3337,3 +3339,43 @@ def method(self) -> None:
33373339
# read-only property, TypeError if it's a builtin class.
33383340
pass
33393341
return method
3342+
3343+
3344+
def is_protocol(tp: type, /) -> bool:
3345+
"""Return True if the given type is a Protocol.
3346+
3347+
Example::
3348+
3349+
>>> from typing import Protocol, is_protocol
3350+
>>> class P(Protocol):
3351+
... def a(self) -> str: ...
3352+
... b: int
3353+
>>> is_protocol(P)
3354+
True
3355+
>>> is_protocol(int)
3356+
False
3357+
"""
3358+
return (
3359+
isinstance(tp, type)
3360+
and getattr(tp, '_is_protocol', False)
3361+
and tp != Protocol
3362+
)
3363+
3364+
3365+
def get_protocol_members(tp: type, /) -> frozenset[str]:
3366+
"""Return the set of members defined in a Protocol.
3367+
3368+
Example::
3369+
3370+
>>> from typing import Protocol, get_protocol_members
3371+
>>> class P(Protocol):
3372+
... def a(self) -> str: ...
3373+
... b: int
3374+
>>> get_protocol_members(P)
3375+
frozenset({'a', 'b'})
3376+
3377+
Raise a TypeError for arguments that are not Protocols.
3378+
"""
3379+
if not is_protocol(tp):
3380+
raise TypeError(f'{tp!r} is not a Protocol')
3381+
return frozenset(tp.__protocol_attrs__)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Add :func:`typing.get_protocol_members` to return the set of members
2+
defining a :class:`typing.Protocol`. Add :func:`typing.is_protocol` to
3+
check whether a class is a :class:`typing.Protocol`. Patch by Jelle Zijlstra.

0 commit comments

Comments
 (0)