Skip to content

Commit c1baa60

Browse files
committed
Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against
the destruction of weakref'ed objects while iterating.
1 parent dc2a613 commit c1baa60

File tree

6 files changed

+296
-56
lines changed

6 files changed

+296
-56
lines changed

Doc/library/weakref.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ than needed.
159159

160160
.. method:: WeakKeyDictionary.keyrefs()
161161

162-
Return an :term:`iterator` that yields the weak references to the keys.
162+
Return an iterable of the weak references to the keys.
163163

164164

165165
.. class:: WeakValueDictionary([dict])
@@ -182,7 +182,7 @@ These method have the same issues as the and :meth:`keyrefs` method of
182182

183183
.. method:: WeakValueDictionary.valuerefs()
184184

185-
Return an :term:`iterator` that yields the weak references to the values.
185+
Return an iterable of the weak references to the values.
186186

187187

188188
.. class:: WeakSet([elements])

Lib/_weakrefset.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,61 @@
66

77
__all__ = ['WeakSet']
88

9+
10+
class _IterationGuard:
11+
# This context manager registers itself in the current iterators of the
12+
# weak container, such as to delay all removals until the context manager
13+
# exits.
14+
# This technique should be relatively thread-safe (since sets are).
15+
16+
def __init__(self, weakcontainer):
17+
# Don't create cycles
18+
self.weakcontainer = ref(weakcontainer)
19+
20+
def __enter__(self):
21+
w = self.weakcontainer()
22+
if w is not None:
23+
w._iterating.add(self)
24+
return self
25+
26+
def __exit__(self, e, t, b):
27+
w = self.weakcontainer()
28+
if w is not None:
29+
s = w._iterating
30+
s.remove(self)
31+
if not s:
32+
w._commit_removals()
33+
34+
935
class WeakSet:
1036
def __init__(self, data=None):
1137
self.data = set()
1238
def _remove(item, selfref=ref(self)):
1339
self = selfref()
1440
if self is not None:
15-
self.data.discard(item)
41+
if self._iterating:
42+
self._pending_removals.append(item)
43+
else:
44+
self.data.discard(item)
1645
self._remove = _remove
46+
# A list of keys to be removed
47+
self._pending_removals = []
48+
self._iterating = set()
1749
if data is not None:
1850
self.update(data)
1951

52+
def _commit_removals(self):
53+
l = self._pending_removals
54+
discard = self.data.discard
55+
while l:
56+
discard(l.pop())
57+
2058
def __iter__(self):
21-
for itemref in self.data:
22-
item = itemref()
23-
if item is not None:
24-
yield item
59+
with _IterationGuard(self):
60+
for itemref in self.data:
61+
item = itemref()
62+
if item is not None:
63+
yield item
2564

2665
def __len__(self):
2766
return sum(x() is not None for x in self.data)
@@ -34,15 +73,21 @@ def __reduce__(self):
3473
getattr(self, '__dict__', None))
3574

3675
def add(self, item):
76+
if self._pending_removals:
77+
self._commit_removals()
3778
self.data.add(ref(item, self._remove))
3879

3980
def clear(self):
81+
if self._pending_removals:
82+
self._commit_removals()
4083
self.data.clear()
4184

4285
def copy(self):
4386
return self.__class__(self)
4487

4588
def pop(self):
89+
if self._pending_removals:
90+
self._commit_removals()
4691
while True:
4792
try:
4893
itemref = self.data.pop()
@@ -53,17 +98,24 @@ def pop(self):
5398
return item
5499

55100
def remove(self, item):
101+
if self._pending_removals:
102+
self._commit_removals()
56103
self.data.remove(ref(item))
57104

58105
def discard(self, item):
106+
if self._pending_removals:
107+
self._commit_removals()
59108
self.data.discard(ref(item))
60109

61110
def update(self, other):
111+
if self._pending_removals:
112+
self._commit_removals()
62113
if isinstance(other, self.__class__):
63114
self.data.update(other.data)
64115
else:
65116
for element in other:
66117
self.add(element)
118+
67119
def __ior__(self, other):
68120
self.update(other)
69121
return self
@@ -82,11 +134,15 @@ def difference(self, other):
82134
__sub__ = difference
83135

84136
def difference_update(self, other):
137+
if self._pending_removals:
138+
self._commit_removals()
85139
if self is other:
86140
self.data.clear()
87141
else:
88142
self.data.difference_update(ref(item) for item in other)
89143
def __isub__(self, other):
144+
if self._pending_removals:
145+
self._commit_removals()
90146
if self is other:
91147
self.data.clear()
92148
else:
@@ -98,8 +154,12 @@ def intersection(self, other):
98154
__and__ = intersection
99155

100156
def intersection_update(self, other):
157+
if self._pending_removals:
158+
self._commit_removals()
101159
self.data.intersection_update(ref(item) for item in other)
102160
def __iand__(self, other):
161+
if self._pending_removals:
162+
self._commit_removals()
103163
self.data.intersection_update(ref(item) for item in other)
104164
return self
105165

@@ -127,11 +187,15 @@ def symmetric_difference(self, other):
127187
__xor__ = symmetric_difference
128188

129189
def symmetric_difference_update(self, other):
190+
if self._pending_removals:
191+
self._commit_removals()
130192
if self is other:
131193
self.data.clear()
132194
else:
133195
self.data.symmetric_difference_update(ref(item) for item in other)
134196
def __ixor__(self, other):
197+
if self._pending_removals:
198+
self._commit_removals()
135199
if self is other:
136200
self.data.clear()
137201
else:

Lib/test/test_weakref.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import collections
55
import weakref
66
import operator
7+
import contextlib
8+
import copy
79

810
from test import support
911

@@ -788,6 +790,10 @@ def __init__(self, arg):
788790
self.arg = arg
789791
def __repr__(self):
790792
return "<Object %r>" % self.arg
793+
def __eq__(self, other):
794+
if isinstance(other, Object):
795+
return self.arg == other.arg
796+
return NotImplemented
791797
def __lt__(self, other):
792798
if isinstance(other, Object):
793799
return self.arg < other.arg
@@ -935,6 +941,87 @@ def check_iters(self, dict):
935941
self.assertFalse(values,
936942
"itervalues() did not touch all values")
937943

944+
def check_weak_destroy_while_iterating(self, dict, objects, iter_name):
945+
n = len(dict)
946+
it = iter(getattr(dict, iter_name)())
947+
next(it) # Trigger internal iteration
948+
# Destroy an object
949+
del objects[-1]
950+
gc.collect() # just in case
951+
# We have removed either the first consumed object, or another one
952+
self.assertIn(len(list(it)), [len(objects), len(objects) - 1])
953+
del it
954+
# The removal has been committed
955+
self.assertEqual(len(dict), n - 1)
956+
957+
def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext):
958+
# Check that we can explicitly mutate the weak dict without
959+
# interfering with delayed removal.
960+
# `testcontext` should create an iterator, destroy one of the
961+
# weakref'ed objects and then return a new key/value pair corresponding
962+
# to the destroyed object.
963+
with testcontext() as (k, v):
964+
self.assertFalse(k in dict)
965+
with testcontext() as (k, v):
966+
self.assertRaises(KeyError, dict.__delitem__, k)
967+
self.assertFalse(k in dict)
968+
with testcontext() as (k, v):
969+
self.assertRaises(KeyError, dict.pop, k)
970+
self.assertFalse(k in dict)
971+
with testcontext() as (k, v):
972+
dict[k] = v
973+
self.assertEqual(dict[k], v)
974+
ddict = copy.copy(dict)
975+
with testcontext() as (k, v):
976+
dict.update(ddict)
977+
self.assertEqual(dict, ddict)
978+
with testcontext() as (k, v):
979+
dict.clear()
980+
self.assertEqual(len(dict), 0)
981+
982+
def test_weak_keys_destroy_while_iterating(self):
983+
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
984+
dict, objects = self.make_weak_keyed_dict()
985+
self.check_weak_destroy_while_iterating(dict, objects, 'keys')
986+
self.check_weak_destroy_while_iterating(dict, objects, 'items')
987+
self.check_weak_destroy_while_iterating(dict, objects, 'values')
988+
self.check_weak_destroy_while_iterating(dict, objects, 'keyrefs')
989+
dict, objects = self.make_weak_keyed_dict()
990+
@contextlib.contextmanager
991+
def testcontext():
992+
try:
993+
it = iter(dict.items())
994+
next(it)
995+
# Schedule a key/value for removal and recreate it
996+
v = objects.pop().arg
997+
gc.collect() # just in case
998+
yield Object(v), v
999+
finally:
1000+
it = None # should commit all removals
1001+
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
1002+
1003+
def test_weak_values_destroy_while_iterating(self):
1004+
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
1005+
dict, objects = self.make_weak_valued_dict()
1006+
self.check_weak_destroy_while_iterating(dict, objects, 'keys')
1007+
self.check_weak_destroy_while_iterating(dict, objects, 'items')
1008+
self.check_weak_destroy_while_iterating(dict, objects, 'values')
1009+
self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs')
1010+
self.check_weak_destroy_while_iterating(dict, objects, 'valuerefs')
1011+
dict, objects = self.make_weak_valued_dict()
1012+
@contextlib.contextmanager
1013+
def testcontext():
1014+
try:
1015+
it = iter(dict.items())
1016+
next(it)
1017+
# Schedule a key/value for removal and recreate it
1018+
k = objects.pop().arg
1019+
gc.collect() # just in case
1020+
yield k, Object(k)
1021+
finally:
1022+
it = None # should commit all removals
1023+
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
1024+
9381025
def test_make_weak_keyed_dict_from_dict(self):
9391026
o = Object(3)
9401027
dict = weakref.WeakKeyDictionary({o:364})

Lib/test/test_weakset.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import warnings
1111
import collections
1212
from collections import UserString as ustr
13+
import gc
14+
import contextlib
1315

1416

1517
class Foo:
@@ -307,6 +309,54 @@ def test_eq(self):
307309
self.assertFalse(self.s == WeakSet([Foo]))
308310
self.assertFalse(self.s == 1)
309311

312+
def test_weak_destroy_while_iterating(self):
313+
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
314+
# Create new items to be sure no-one else holds a reference
315+
items = [ustr(c) for c in ('a', 'b', 'c')]
316+
s = WeakSet(items)
317+
it = iter(s)
318+
next(it) # Trigger internal iteration
319+
# Destroy an item
320+
del items[-1]
321+
gc.collect() # just in case
322+
# We have removed either the first consumed items, or another one
323+
self.assertIn(len(list(it)), [len(items), len(items) - 1])
324+
del it
325+
# The removal has been committed
326+
self.assertEqual(len(s), len(items))
327+
328+
def test_weak_destroy_and_mutate_while_iterating(self):
329+
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
330+
items = [ustr(c) for c in string.ascii_letters]
331+
s = WeakSet(items)
332+
@contextlib.contextmanager
333+
def testcontext():
334+
try:
335+
it = iter(s)
336+
next(it)
337+
# Schedule an item for removal and recreate it
338+
u = ustr(str(items.pop()))
339+
gc.collect() # just in case
340+
yield u
341+
finally:
342+
it = None # should commit all removals
343+
344+
with testcontext() as u:
345+
self.assertFalse(u in s)
346+
with testcontext() as u:
347+
self.assertRaises(KeyError, s.remove, u)
348+
self.assertFalse(u in s)
349+
with testcontext() as u:
350+
s.add(u)
351+
self.assertTrue(u in s)
352+
t = s.copy()
353+
with testcontext() as u:
354+
s.update(t)
355+
self.assertEqual(len(s), len(t))
356+
with testcontext() as u:
357+
s.clear()
358+
self.assertEqual(len(s), 0)
359+
310360

311361
def test_main(verbose=None):
312362
support.run_unittest(TestWeakSet)

0 commit comments

Comments
 (0)