@@ -398,21 +398,33 @@ def clear_overloads():
398
398
}
399
399
400
400
401
+ _EXCLUDED_ATTRS = {
402
+ "__abstractmethods__" , "__annotations__" , "__weakref__" , "_is_protocol" ,
403
+ "_is_runtime_protocol" , "__dict__" , "__slots__" , "__parameters__" ,
404
+ "__orig_bases__" , "__module__" , "_MutableMapping__marker" , "__doc__" ,
405
+ "__subclasshook__" , "__orig_class__" , "__init__" , "__new__" ,
406
+ }
407
+
408
+ if sys .version_info < (3 , 8 ):
409
+ _EXCLUDED_ATTRS |= {
410
+ "_gorg" , "__next_in_mro__" , "__extra__" , "__tree_hash__" , "__args__" ,
411
+ "__origin__"
412
+ }
413
+
414
+ if sys .version_info >= (3 , 9 ):
415
+ _EXCLUDED_ATTRS .add ("__class_getitem__" )
416
+
417
+ _EXCLUDED_ATTRS = frozenset (_EXCLUDED_ATTRS )
418
+
419
+
401
420
def _get_protocol_attrs (cls ):
402
421
attrs = set ()
403
422
for base in cls .__mro__ [:- 1 ]: # without object
404
423
if base .__name__ in ('Protocol' , 'Generic' ):
405
424
continue
406
425
annotations = getattr (base , '__annotations__' , {})
407
426
for attr in list (base .__dict__ .keys ()) + list (annotations .keys ()):
408
- if (not attr .startswith ('_abc_' ) and attr not in (
409
- '__abstractmethods__' , '__annotations__' , '__weakref__' ,
410
- '_is_protocol' , '_is_runtime_protocol' , '__dict__' ,
411
- '__args__' , '__slots__' ,
412
- '__next_in_mro__' , '__parameters__' , '__origin__' ,
413
- '__orig_bases__' , '__extra__' , '__tree_hash__' ,
414
- '__doc__' , '__subclasshook__' , '__init__' , '__new__' ,
415
- '__module__' , '_MutableMapping__marker' , '_gorg' )):
427
+ if (not attr .startswith ('_abc_' ) and attr not in _EXCLUDED_ATTRS ):
416
428
attrs .add (attr )
417
429
return attrs
418
430
@@ -468,11 +480,18 @@ def _caller(depth=2):
468
480
return None
469
481
470
482
471
- # 3.8+
472
- if hasattr (typing , 'Protocol' ):
483
+ # A bug in runtime-checkable protocols was fixed in 3.10+,
484
+ # but we backport it to all versions
485
+ if sys .version_info >= (3 , 10 ):
473
486
Protocol = typing .Protocol
474
- # 3.7
487
+ runtime_checkable = typing . runtime_checkable
475
488
else :
489
+ def _allow_reckless_class_checks (depth = 4 ):
490
+ """Allow instance and class checks for special stdlib modules.
491
+ The abc and functools modules indiscriminately call isinstance() and
492
+ issubclass() on the whole MRO of a user class, which may contain protocols.
493
+ """
494
+ return _caller (depth ) in {'abc' , 'functools' , None }
476
495
477
496
def _no_init (self , * args , ** kwargs ):
478
497
if type (self )._is_protocol :
@@ -484,11 +503,19 @@ class _ProtocolMeta(abc.ABCMeta):
484
503
def __instancecheck__ (cls , instance ):
485
504
# We need this method for situations where attributes are
486
505
# assigned in __init__.
487
- if ((not getattr (cls , '_is_protocol' , False ) or
506
+ is_protocol_cls = getattr (cls , "_is_protocol" , False )
507
+ if (
508
+ is_protocol_cls and
509
+ not getattr (cls , '_is_runtime_protocol' , False ) and
510
+ not _allow_reckless_class_checks (depth = 2 )
511
+ ):
512
+ raise TypeError ("Instance and class checks can only be used with"
513
+ " @runtime_checkable protocols" )
514
+ if ((not is_protocol_cls or
488
515
_is_callable_members_only (cls )) and
489
516
issubclass (instance .__class__ , cls )):
490
517
return True
491
- if cls . _is_protocol :
518
+ if is_protocol_cls :
492
519
if all (hasattr (instance , attr ) and
493
520
(not callable (getattr (cls , attr , None )) or
494
521
getattr (instance , attr ) is not None )
@@ -530,6 +557,7 @@ def meth(self) -> T:
530
557
"""
531
558
__slots__ = ()
532
559
_is_protocol = True
560
+ _is_runtime_protocol = False
533
561
534
562
def __new__ (cls , * args , ** kwds ):
535
563
if cls is Protocol :
@@ -581,12 +609,12 @@ def _proto_hook(other):
581
609
if not cls .__dict__ .get ('_is_protocol' , None ):
582
610
return NotImplemented
583
611
if not getattr (cls , '_is_runtime_protocol' , False ):
584
- if _caller ( depth = 3 ) in { 'abc' , 'functools' } :
612
+ if _allow_reckless_class_checks () :
585
613
return NotImplemented
586
614
raise TypeError ("Instance and class checks can only be used with"
587
615
" @runtime protocols" )
588
616
if not _is_callable_members_only (cls ):
589
- if _caller ( depth = 3 ) in { 'abc' , 'functools' } :
617
+ if _allow_reckless_class_checks () :
590
618
return NotImplemented
591
619
raise TypeError ("Protocols with non-method members"
592
620
" don't support issubclass()" )
@@ -625,12 +653,6 @@ def _proto_hook(other):
625
653
f' protocols, got { repr (base )} ' )
626
654
cls .__init__ = _no_init
627
655
628
-
629
- # 3.8+
630
- if hasattr (typing , 'runtime_checkable' ):
631
- runtime_checkable = typing .runtime_checkable
632
- # 3.7
633
- else :
634
656
def runtime_checkable (cls ):
635
657
"""Mark a protocol class as a runtime protocol, so that it
636
658
can be used with isinstance() and issubclass(). Raise TypeError
@@ -639,7 +661,10 @@ def runtime_checkable(cls):
639
661
This allows a simple-minded structural check very similar to the
640
662
one-offs in collections.abc such as Hashable.
641
663
"""
642
- if not isinstance (cls , _ProtocolMeta ) or not cls ._is_protocol :
664
+ if not (
665
+ (isinstance (cls , _ProtocolMeta ) or issubclass (cls , typing .Generic ))
666
+ and getattr (cls , "_is_protocol" , False )
667
+ ):
643
668
raise TypeError ('@runtime_checkable can be only applied to protocol classes,'
644
669
f' got { cls !r} ' )
645
670
cls ._is_runtime_protocol = True
0 commit comments