Skip to content

Commit

Permalink
Fix sequence timeout deadlock (#13322)
Browse files Browse the repository at this point in the history
* Add a test for deadlock after sequence worker timeout

* Call task_done even if the task timeouted

* catch dead worker warning

* fix line length

* Increase deadlock detection timeout to prevent flakiness
  • Loading branch information
andreyz4k authored and fchollet committed Sep 15, 2019
1 parent 7869134 commit cf9595a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
4 changes: 3 additions & 1 deletion keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,14 +608,16 @@ def get(self):
try:
future = self.queue.get(block=True)
inputs = future.get(timeout=30)
self.queue.task_done()
except mp.TimeoutError:
idx = future.idx
warnings.warn(
'The input {} could not be retrieved.'
' It could be because a worker has died.'.format(idx),
UserWarning)
inputs = self.sequence[idx]
finally:
self.queue.task_done()

if inputs is not None:
yield inputs
except Exception:
Expand Down
46 changes: 46 additions & 0 deletions tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import tarfile
import threading
import signal
import shutil
import zipfile
from itertools import cycle
Expand Down Expand Up @@ -211,6 +212,25 @@ def on_epoch_end(self):
pass


class SlowSequence(Sequence):
def __init__(self, shape, value=1.0):
self.shape = shape
self.inner = value
self.wait = True

def __getitem__(self, item):
if self.wait:
self.wait = False
time.sleep(40)
return np.ones(self.shape, dtype=np.uint32) * item * self.inner

def __len__(self):
return 10

def on_epoch_end(self):
pass


@threadsafe_generator
def create_generator_from_sequence_threads(ds):
for i in cycle(range(len(ds))):
Expand Down Expand Up @@ -335,6 +355,32 @@ def test_ordered_enqueuer_fail_threads():
next(gen_output)


def test_ordered_enqueuer_timeout_threads():
enqueuer = OrderedEnqueuer(SlowSequence([3, 10, 10, 3]),
use_multiprocessing=False)

def handler(signum, frame):
raise TimeoutError('Sequence deadlocked')

old = signal.signal(signal.SIGALRM, handler)
signal.setitimer(signal.ITIMER_REAL, 60)
with pytest.warns(UserWarning) as record:
enqueuer.start(5, 10)
gen_output = enqueuer.get()
for epoch_num in range(2):
acc = []
for i in range(10):
acc.append(next(gen_output)[0, 0, 0, 0])
assert acc == list(range(10)), 'Order was not keep in ' \
'OrderedEnqueuer with threads'
enqueuer.stop()
assert len(record) == 1
assert str(record[0].message) == 'The input 0 could not be retrieved. ' \
'It could be because a worker has died.'
signal.setitimer(signal.ITIMER_REAL, 0)
signal.signal(signal.SIGALRM, old)


@use_spawn
def test_on_epoch_end_processes():
enqueuer = OrderedEnqueuer(DummySequence([3, 10, 10, 3]),
Expand Down

0 comments on commit cf9595a

Please sign in to comment.