Skip to content

Commit 7e998c2

Browse files
authoredApr 12, 2023
Fix issue when non runtime_protocol does not raise TypeError (#132)
Backport of CPython PR 26067 (python/cpython#26067)
1 parent 25b0971 commit 7e998c2

File tree

3 files changed

+76
-28
lines changed

3 files changed

+76
-28
lines changed
 

‎CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
- Add `typing_extensions.Buffer`, a marker class for buffer types, as proposed
44
by PEP 688. Equivalent to `collections.abc.Buffer` in Python 3.12. Patch by
55
Jelle Zijlstra.
6+
- Backport [CPython PR 26067](https://github.com/python/cpython/pull/26067)
7+
(originally by Yurii Karabas), ensuring that `isinstance()` calls on
8+
protocols raise `TypeError` when the protocol is not decorated with
9+
`@runtime_checkable`. Patch by Alex Waygood.
610

711
# Release 4.5.0 (February 14, 2023)
812

‎src/test_typing_extensions.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,22 @@ class E(C, BP): pass
14211421
self.assertNotIsInstance(D(), E)
14221422
self.assertNotIsInstance(E(), D)
14231423

1424+
@skipUnless(
1425+
hasattr(typing, "Protocol"),
1426+
"Test is only relevant if typing.Protocol exists"
1427+
)
1428+
def test_runtimecheckable_on_typing_dot_Protocol(self):
1429+
@runtime_checkable
1430+
class Foo(typing.Protocol):
1431+
x: int
1432+
1433+
class Bar:
1434+
def __init__(self):
1435+
self.x = 42
1436+
1437+
self.assertIsInstance(Bar(), Foo)
1438+
self.assertNotIsInstance(object(), Foo)
1439+
14241440
def test_no_instantiation(self):
14251441
class P(Protocol): pass
14261442
with self.assertRaises(TypeError):
@@ -1829,11 +1845,7 @@ def meth(self):
18291845
self.assertTrue(P._is_protocol)
18301846
self.assertTrue(PR._is_protocol)
18311847
self.assertTrue(PG._is_protocol)
1832-
if hasattr(typing, 'Protocol'):
1833-
self.assertFalse(P._is_runtime_protocol)
1834-
else:
1835-
with self.assertRaises(AttributeError):
1836-
self.assertFalse(P._is_runtime_protocol)
1848+
self.assertFalse(P._is_runtime_protocol)
18371849
self.assertTrue(PR._is_runtime_protocol)
18381850
self.assertTrue(PG[int]._is_protocol)
18391851
self.assertEqual(typing_extensions._get_protocol_attrs(P), {'meth'})
@@ -1929,6 +1941,13 @@ class CustomProtocol(TestCase, Protocol):
19291941
class CustomContextManager(typing.ContextManager, Protocol):
19301942
pass
19311943

1944+
def test_non_runtime_protocol_isinstance_check(self):
1945+
class P(Protocol):
1946+
x: int
1947+
1948+
with self.assertRaisesRegex(TypeError, "@runtime_checkable"):
1949+
isinstance(1, P)
1950+
19321951
def test_no_init_same_for_different_protocol_implementations(self):
19331952
class CustomProtocolWithoutInitA(Protocol):
19341953
pass
@@ -3314,7 +3333,7 @@ def test_typing_extensions_defers_when_possible(self):
33143333
'is_typeddict',
33153334
}
33163335
if sys.version_info < (3, 10):
3317-
exclude |= {'get_args', 'get_origin'}
3336+
exclude |= {'get_args', 'get_origin', 'Protocol', 'runtime_checkable'}
33183337
if sys.version_info < (3, 11):
33193338
exclude |= {'final', 'NamedTuple', 'Any'}
33203339
for item in typing_extensions.__all__:

‎src/typing_extensions.py

+47-22
Original file line numberDiff line numberDiff line change
@@ -398,21 +398,33 @@ def clear_overloads():
398398
}
399399

400400

401+
_EXCLUDED_ATTRS = {
402+
"__abstractmethods__", "__annotations__", "__weakref__", "_is_protocol",
403+
"_is_runtime_protocol", "__dict__", "__slots__", "__parameters__",
404+
"__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__",
405+
"__subclasshook__", "__orig_class__", "__init__", "__new__",
406+
}
407+
408+
if sys.version_info < (3, 8):
409+
_EXCLUDED_ATTRS |= {
410+
"_gorg", "__next_in_mro__", "__extra__", "__tree_hash__", "__args__",
411+
"__origin__"
412+
}
413+
414+
if sys.version_info >= (3, 9):
415+
_EXCLUDED_ATTRS.add("__class_getitem__")
416+
417+
_EXCLUDED_ATTRS = frozenset(_EXCLUDED_ATTRS)
418+
419+
401420
def _get_protocol_attrs(cls):
402421
attrs = set()
403422
for base in cls.__mro__[:-1]: # without object
404423
if base.__name__ in ('Protocol', 'Generic'):
405424
continue
406425
annotations = getattr(base, '__annotations__', {})
407426
for attr in list(base.__dict__.keys()) + list(annotations.keys()):
408-
if (not attr.startswith('_abc_') and attr not in (
409-
'__abstractmethods__', '__annotations__', '__weakref__',
410-
'_is_protocol', '_is_runtime_protocol', '__dict__',
411-
'__args__', '__slots__',
412-
'__next_in_mro__', '__parameters__', '__origin__',
413-
'__orig_bases__', '__extra__', '__tree_hash__',
414-
'__doc__', '__subclasshook__', '__init__', '__new__',
415-
'__module__', '_MutableMapping__marker', '_gorg')):
427+
if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS):
416428
attrs.add(attr)
417429
return attrs
418430

@@ -468,11 +480,18 @@ def _caller(depth=2):
468480
return None
469481

470482

471-
# 3.8+
472-
if hasattr(typing, 'Protocol'):
483+
# A bug in runtime-checkable protocols was fixed in 3.10+,
484+
# but we backport it to all versions
485+
if sys.version_info >= (3, 10):
473486
Protocol = typing.Protocol
474-
# 3.7
487+
runtime_checkable = typing.runtime_checkable
475488
else:
489+
def _allow_reckless_class_checks(depth=4):
490+
"""Allow instance and class checks for special stdlib modules.
491+
The abc and functools modules indiscriminately call isinstance() and
492+
issubclass() on the whole MRO of a user class, which may contain protocols.
493+
"""
494+
return _caller(depth) in {'abc', 'functools', None}
476495

477496
def _no_init(self, *args, **kwargs):
478497
if type(self)._is_protocol:
@@ -484,11 +503,19 @@ class _ProtocolMeta(abc.ABCMeta):
484503
def __instancecheck__(cls, instance):
485504
# We need this method for situations where attributes are
486505
# assigned in __init__.
487-
if ((not getattr(cls, '_is_protocol', False) or
506+
is_protocol_cls = getattr(cls, "_is_protocol", False)
507+
if (
508+
is_protocol_cls and
509+
not getattr(cls, '_is_runtime_protocol', False) and
510+
not _allow_reckless_class_checks(depth=2)
511+
):
512+
raise TypeError("Instance and class checks can only be used with"
513+
" @runtime_checkable protocols")
514+
if ((not is_protocol_cls or
488515
_is_callable_members_only(cls)) and
489516
issubclass(instance.__class__, cls)):
490517
return True
491-
if cls._is_protocol:
518+
if is_protocol_cls:
492519
if all(hasattr(instance, attr) and
493520
(not callable(getattr(cls, attr, None)) or
494521
getattr(instance, attr) is not None)
@@ -530,6 +557,7 @@ def meth(self) -> T:
530557
"""
531558
__slots__ = ()
532559
_is_protocol = True
560+
_is_runtime_protocol = False
533561

534562
def __new__(cls, *args, **kwds):
535563
if cls is Protocol:
@@ -581,12 +609,12 @@ def _proto_hook(other):
581609
if not cls.__dict__.get('_is_protocol', None):
582610
return NotImplemented
583611
if not getattr(cls, '_is_runtime_protocol', False):
584-
if _caller(depth=3) in {'abc', 'functools'}:
612+
if _allow_reckless_class_checks():
585613
return NotImplemented
586614
raise TypeError("Instance and class checks can only be used with"
587615
" @runtime protocols")
588616
if not _is_callable_members_only(cls):
589-
if _caller(depth=3) in {'abc', 'functools'}:
617+
if _allow_reckless_class_checks():
590618
return NotImplemented
591619
raise TypeError("Protocols with non-method members"
592620
" don't support issubclass()")
@@ -625,12 +653,6 @@ def _proto_hook(other):
625653
f' protocols, got {repr(base)}')
626654
cls.__init__ = _no_init
627655

628-
629-
# 3.8+
630-
if hasattr(typing, 'runtime_checkable'):
631-
runtime_checkable = typing.runtime_checkable
632-
# 3.7
633-
else:
634656
def runtime_checkable(cls):
635657
"""Mark a protocol class as a runtime protocol, so that it
636658
can be used with isinstance() and issubclass(). Raise TypeError
@@ -639,7 +661,10 @@ def runtime_checkable(cls):
639661
This allows a simple-minded structural check very similar to the
640662
one-offs in collections.abc such as Hashable.
641663
"""
642-
if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol:
664+
if not (
665+
(isinstance(cls, _ProtocolMeta) or issubclass(cls, typing.Generic))
666+
and getattr(cls, "_is_protocol", False)
667+
):
643668
raise TypeError('@runtime_checkable can be only applied to protocol classes,'
644669
f' got {cls!r}')
645670
cls._is_runtime_protocol = True

0 commit comments

Comments
 (0)
Please sign in to comment.