From d3863ed6019cef35e782bbff7db6a2363b42b1c9 Mon Sep 17 00:00:00 2001 From: erjia Date: Wed, 15 Feb 2023 22:09:07 +0000 Subject: [PATCH] Handle terminate --- torchdata/datapipes/iter/util/prefetcher.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py index 961f250de..1d7883c6c 100644 --- a/torchdata/datapipes/iter/util/prefetcher.py +++ b/torchdata/datapipes/iter/util/prefetcher.py @@ -88,11 +88,14 @@ def __iter__(self): thread.start() self.thread = thread + # Lazily import to prevent circular import + from torchdata.dataloader2 import communication + while not prefetch_data.stop_iteration or len(prefetch_data.prefetch_buffer) > 0: if len(prefetch_data.prefetch_buffer) > 0: data = prefetch_data.prefetch_buffer.popleft() if isinstance(data, Exception): - if isinstance(data, StopIteration): + if isinstance(data, (StopIteration, communication.iter.TerminateRequired)): break raise data yield data @@ -227,11 +230,14 @@ def __iter__(self): thread.start() self.thread = thread + # Lazily import to prevent circular import + from torchdata.dataloader2 import communication + while not prefetch_data.stop_iteration or len(prefetch_data.prefetch_buffer) > 0: if len(prefetch_data.prefetch_buffer) > 0: data = prefetch_data.prefetch_buffer.popleft() if isinstance(data, Exception): - if isinstance(data, StopIteration): + if isinstance(data, (StopIteration, communication.iter.TerminateRequired)): break raise data yield data