Skip to content

Commit 7540a43

Browse files
gh-92261: Disallow iteration of Union (and other special forms) (GH-92262) (GH-92582)
(cherry picked from commit 4739997) Co-authored-by: Matthew Rahtz <matthew.rahtz@gmail.com>
1 parent 74c094d commit 7540a43

File tree

4 files changed

+72
-5
lines changed

4 files changed

+72
-5
lines changed

Lib/test/test_genericalias.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,5 +487,25 @@ def test_del_iter(self):
487487
del iter_x
488488

489489

490+
class TypeIterationTests(unittest.TestCase):
491+
_UNITERABLE_TYPES = (list, tuple)
492+
493+
def test_cannot_iterate(self):
494+
for test_type in self._UNITERABLE_TYPES:
495+
with self.subTest(type=test_type):
496+
expected_error_regex = "object is not iterable"
497+
with self.assertRaisesRegex(TypeError, expected_error_regex):
498+
iter(test_type)
499+
with self.assertRaisesRegex(TypeError, expected_error_regex):
500+
list(test_type)
501+
with self.assertRaisesRegex(TypeError, expected_error_regex):
502+
for _ in test_type:
503+
pass
504+
505+
def test_is_not_instance_of_iterable(self):
506+
for type_to_test in self._UNITERABLE_TYPES:
507+
self.assertNotIsInstance(type_to_test, Iterable)
508+
509+
490510
if __name__ == "__main__":
491511
unittest.main()

Lib/test/test_typing.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7348,6 +7348,37 @@ def test_all_exported_names(self):
73487348
self.assertSetEqual(computed_all, actual_all)
73497349

73507350

7351+
class TypeIterationTests(BaseTestCase):
7352+
_UNITERABLE_TYPES = (
7353+
Any,
7354+
Union,
7355+
Union[str, int],
7356+
Union[str, T],
7357+
List,
7358+
Tuple,
7359+
Callable,
7360+
Callable[..., T],
7361+
Callable[[T], str],
7362+
Annotated,
7363+
Annotated[T, ''],
7364+
)
7365+
7366+
def test_cannot_iterate(self):
7367+
expected_error_regex = "object is not iterable"
7368+
for test_type in self._UNITERABLE_TYPES:
7369+
with self.subTest(type=test_type):
7370+
with self.assertRaisesRegex(TypeError, expected_error_regex):
7371+
iter(test_type)
7372+
with self.assertRaisesRegex(TypeError, expected_error_regex):
7373+
list(test_type)
7374+
with self.assertRaisesRegex(TypeError, expected_error_regex):
7375+
for _ in test_type:
7376+
pass
7377+
7378+
def test_is_not_instance_of_iterable(self):
7379+
for type_to_test in self._UNITERABLE_TYPES:
7380+
self.assertNotIsInstance(type_to_test, collections.abc.Iterable)
7381+
73517382

73527383
if __name__ == '__main__':
73537384
main()

Lib/typing.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,24 @@ def __deepcopy__(self, memo):
405405
return self
406406

407407

408+
class _NotIterable:
409+
"""Mixin to prevent iteration, without being compatible with Iterable.
410+
411+
That is, we could do:
412+
def __iter__(self): raise TypeError()
413+
But this would make users of this mixin duck type-compatible with
414+
collections.abc.Iterable - isinstance(foo, Iterable) would be True.
415+
416+
Luckily, we can instead prevent iteration by setting __iter__ to None, which
417+
is treated specially.
418+
"""
419+
420+
__iter__ = None
421+
422+
408423
# Internal indicator of special typing constructs.
409424
# See __doc__ instance attribute for specific docs.
410-
class _SpecialForm(_Final, _root=True):
425+
class _SpecialForm(_Final, _NotIterable, _root=True):
411426
__slots__ = ('_name', '__doc__', '_getitem')
412427

413428
def __init__(self, getitem):
@@ -1498,7 +1513,7 @@ def __iter__(self):
14981513
# 1 for List and 2 for Dict. It may be -1 if variable number of
14991514
# parameters are accepted (needs custom __getitem__).
15001515

1501-
class _SpecialGenericAlias(_BaseGenericAlias, _root=True):
1516+
class _SpecialGenericAlias(_NotIterable, _BaseGenericAlias, _root=True):
15021517
def __init__(self, origin, nparams, *, inst=True, name=None):
15031518
if name is None:
15041519
name = origin.__name__
@@ -1541,7 +1556,7 @@ def __or__(self, right):
15411556
def __ror__(self, left):
15421557
return Union[left, self]
15431558

1544-
class _CallableGenericAlias(_GenericAlias, _root=True):
1559+
class _CallableGenericAlias(_NotIterable, _GenericAlias, _root=True):
15451560
def __repr__(self):
15461561
assert self._name == 'Callable'
15471562
args = self.__args__
@@ -1606,7 +1621,7 @@ def __getitem__(self, params):
16061621
return self.copy_with(params)
16071622

16081623

1609-
class _UnionGenericAlias(_GenericAlias, _root=True):
1624+
class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
16101625
def copy_with(self, params):
16111626
return Union[params]
16121627

@@ -2046,7 +2061,7 @@ def _proto_hook(other):
20462061
cls.__init__ = _no_init_or_replace_init
20472062

20482063

2049-
class _AnnotatedAlias(_GenericAlias, _root=True):
2064+
class _AnnotatedAlias(_NotIterable, _GenericAlias, _root=True):
20502065
"""Runtime representation of an annotated type.
20512066
20522067
At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't'
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix hang when trying to iterate over a ``typing.Union``.

0 commit comments

Comments
 (0)