Fix GeneratorEnqueuer Multithreading on Windows (and Linux...) #8662
Changes from 3 commits
@@ -612,64 +612,97 @@ def __init__(self, generator,
                 seed=None):
        self.wait_time = wait_time
        self._generator = generator
        self._use_multiprocessing = use_multiprocessing
        if os.name is 'nt' and use_multiprocessing is True:
            # On Windows, avoid **SYSTEMATIC** error in `multiprocessing`:
            # `TypeError: can't pickle generator objects`
            # => Suggest multithreading instead of multiprocessing on Windows
            raise ValueError('Using a generator with `use_multiprocessing=True`'
                             ' is not supported on Windows (no marshalling of'
                             ' generators across process boundaries). Instead,'
                             ' use single thread/process or multithreading.')
        else:
            self._use_multiprocessing = use_multiprocessing
        self._threads = []
        self._stop_event = None
        self._manager = None
        self.queue = None
        self.seed = seed

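The error message quoted in the comment above can be reproduced in isolation. On Windows, `multiprocessing` starts workers with the spawn method, which pickles the process target (here a bound method whose `self` holds the generator), and CPython cannot pickle generator objects. A stand-alone check, not part of the PR (the `counter` generator below is made up for illustration):

```python
import pickle


def counter():
    """Trivial generator, standing in for the data generator handed to Keras."""
    i = 0
    while True:
        yield i
        i += 1


gen = counter()
try:
    pickle.dumps(gen)
except TypeError as err:
    # Wording varies by Python version, e.g.
    # "can't pickle generator objects" / "cannot pickle 'generator' object"
    print('Pickling failed:', err)
```
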
    def start(self, workers=1, max_queue_size=10):
        """Kicks off threads which add data from the generator into the queue.

        # Arguments
            workers: number of worker threads
            max_queue_size: queue size
                (when full, threads could block on `put()`)
        """

        def data_generator_task():
    def __data_generator_task(self):
        if self._use_multiprocessing is False:
            while not self._stop_event.is_set():
                with self.genlock:
                    try:
                        if self.queue is not None and self.queue.qsize() < self.max_queue_size:
                            # On all OSes, avoid **SYSTEMATIC** error in multithreading mode:
                            # `ValueError: generator already executing`
                            # => Serialize calls to infinite iterator/generator's next() function
                            generator_output = next(self._generator)
                            self.queue.put((True, generator_output))
                        else:
                            time.sleep(self.wait_time)

Sleeping with a lock held seems bad. Shouldn't it be sufficient to guard only the `next(self._generator)` line with the lock?

In my (admittedly limited) experience with generators, the thread is mostly waiting on getting an available slot in the queue rather than on getting samples from the generator. I guess this is all very dependent on your use case scenario.

Well, it's not really a busy wait since it's mostly sleeping. However, I have to agree that there isn't much to be gained by holding the lock for a shorter time. I think it would only matter if putting the object in the queue were a time-consuming action (which it isn't in the threading case).

I was also bothered by the sleep inside the lock and thought about correcting it. Even though the code can be refactored to avoid it, we CAN'T lock only the call to `next()`. This PR makes the queue have a fixed size; if we don't protect both the iterator and the addition to the queue, we can end up with threads being stuck while trying to add to the queue. Because the stop() operation joins and waits for the threads to finish, this will not happen. To avoid sleeping in a lock, the function needs to be heavily refactored and, to echo @philferriere, it might not be 100% worth it.

In the future, I'll just remove this use case. I suspect having [...]. I'll do some profiling and post my findings here. Also, we should try to mimic the OrderedEnqueuer, so this class will get heavily refactored soon anyway (by the end of August).

With use_multiprocessing=False, how do you avoid [...]?

Let me provide some context and then move to benchmarks. Here are the key Keras releases for this discussion: [...]. Before the introduction of Sequences (<=2.0.5), all Generator classes in Keras implemented [...]. When I talk about thread-safe generators, I'm not referring to Python generator methods but to Generator classes like those used by Keras before 2.0.5. These classes used an Iterator and made sure, with minimal locking, that all the heavy computations are done outside the lock. Since C libs can release the GIL, this gave speed improvements. When Sequences were introduced, many of the classes were rewritten and instead of [...]. Below I provide a snippet that can be run with and without the lock (just comment out line 650). To produce something more than a toy example, I use the [...]

Snippet:

    import time
    from keras.preprocessing import image
    from keras.utils.data_utils import GeneratorEnqueuer

    it = image.ImageDataGenerator().flow_from_directory('/path/to/images', target_size=(224, 224), batch_size=512)
    reader = GeneratorEnqueuer(it, use_multiprocessing=False)
    reader.start(workers=16, max_queue_size=16)
    g = reader.get()

    start = time.time()
    n = 100
    for i in range(n):
        x = g.next()
    total_time = time.time() - start
    print('Total time %f sec; Average time: %f sec' % (total_time, total_time / n))

WITH LOCK: Total time 161.514681 sec; Average time: 1.615147 sec
[...]

That's 5 times faster, provided that the underlying iterator is thread-safe. As you can see, [...] In my opinion, removing the lock makes sense (or maybe checking the instance type and/or the existence of specific methods to understand if the lock is necessary?), because if you are using non-thread-safe Python generators you should not use [...]. If that's interesting to you, I can draft a PR so that you can check it out and decide if it's something you want to use.

Please draft a PR. I think we should keep a UX-friendly way of handling Python generators (i.e. a good, clear message stating that they should set workers=1).

Sounds good. I'll put something together over the weekend and I'll tag you to get your feedback.

Yeah, I think it's a good idea to add some specific member variable to indicate that a generator is safe for running multi-threaded so it can be detected at runtime. That could then trigger a warning and enable the locking to serialize the calls to `next()`.
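
To make the "thread-safe generator" distinction above concrete, here is a minimal sketch of the kind of Iterator the pre-2.0.5 generators were built on. It is hypothetical code, not Keras code, and the class and method names are made up: only the index bookkeeping is serialized by the lock, while the heavy batch construction runs outside it, so several worker threads can overlap (C libraries can release the GIL during that work).

```python
import threading

import numpy as np


class ThreadSafeBatchIterator(object):
    """Illustrative pre-2.0.5-style "thread-safe generator" (hypothetical)."""

    def __init__(self, n, batch_size):
        self.n = n
        self.batch_size = batch_size
        self.lock = threading.Lock()
        self._position = 0

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:  # cheap, serialized bookkeeping only
            start = self._position
            self._position = (self._position + self.batch_size) % self.n
        indices = np.arange(start, start + self.batch_size) % self.n
        return self._load_batch(indices)  # heavy work, runs concurrently

    next = __next__  # Python 2 spelling, matching the benchmark above

    def _load_batch(self, indices):
        # Stand-in for reading and decoding images from disk.
        return np.random.random((len(indices), 224, 224, 3))
```

A plain Python generator cannot be shared between threads this way (calling `next()` on an already-executing generator raises `ValueError: generator already executing`), which is why the enqueuer above serializes `next()` with `genlock`.
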
                    except StopIteration:
                        break
                    except Exception as e:
                        # Can't pickle tracebacks.
                        # As a compromise, print the traceback and pickle None instead.

This comment was only valid for the multiprocessing branch. It can be removed here.

                        if not hasattr(e, '__traceback__'):
                            setattr(e, '__traceback__', sys.exc_info()[2])
                        self.queue.put((False, e))
                        self._stop_event.set()
                        break
        else:
            while not self._stop_event.is_set():
                try:
                    if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
                    if self.queue is not None and self.queue.qsize() < self.max_queue_size:
                        generator_output = next(self._generator)
                        self.queue.put((True, generator_output))
                    else:
                        time.sleep(self.wait_time)
                except StopIteration:
                    break
                except Exception as e:
                    # Can't pick tracebacks.
                    # Can't pickle tracebacks.
                    # As a compromise, print the traceback and pickle None instead.
                    if self._use_multiprocessing:
                        traceback.print_exc()
                        setattr(e, '__traceback__', None)
                    elif not hasattr(e, '__traceback__'):
                        setattr(e, '__traceback__', sys.exc_info()[2])
                    traceback.print_exc()

Why remove the traceback when using threads?

This is a question better asked to @de-vri-es, who made commit 4a58b178073f0ba3b166220f7ebd7d56149bfb20.

The exception is put in the queue and rethrown in the main thread. Printing the traceback twice is no good. Also, if someone decides to rethrow the exception you don't want the traceback printed at all. The only reason I kept the print in the multiprocessing=True case is that the traceback can't be pickled and can't be put in an inter-process queue. Unconditionally printing the traceback is not a good idea in my opinion.

I meant "if someone decides to catch the exception you don't want the traceback printed at all". I shouldn't do this from my phone, apparently...

It looks like you undid my work in a merge conflict here.

I don't believe so. The test for multiprocessing=True has been taken out of the while loop for clarity (there is now a code path for multiprocessing and one for multithreading). The diff displayed by GitHub appears misleading, however. @de-vri-es, would you mind looking at the files side by side and telling me if you agree? I would greatly appreciate it.

Oh sorry, I didn't see the full diff (I really shouldn't have continued on my phone). I see the [...]
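
As background for the compromise described in the code comments: exception instances generally pickle fine, but traceback objects do not, which is why the multiprocessing path prints the traceback and ships the exception with `__traceback__` set to `None`. A quick stand-alone check (not part of the PR):

```python
import pickle
import sys

try:
    raise ValueError('boom')
except ValueError as exc:
    tb = sys.exc_info()[2]
    pickle.dumps(exc)      # the exception instance itself pickles fine
    try:
        pickle.dumps(tb)   # ...but its traceback object does not
    except TypeError as err:
        print('Cannot pickle the traceback:', err)
```
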
                    setattr(e, '__traceback__', None)
                    self.queue.put((False, e))
                    self._stop_event.set()
                    break

    def start(self, workers=1, max_queue_size=10):
        """Kicks off threads which add data from the generator into the queue.

        # Arguments
            workers: number of worker threads
            max_queue_size: queue size
                (when full, threads could block on `put()`)
        """
        try:
            self.max_queue_size = max_queue_size
            if self._use_multiprocessing:
                self._manager = multiprocessing.Manager()
                self.queue = self._manager.Queue(maxsize=max_queue_size)
                self._stop_event = multiprocessing.Event()
            else:
                self.queue = queue.Queue()
                # On all OSes, avoid **SYSTEMATIC** error in multithreading mode:
                # `ValueError: generator already executing`
                # => Serialize calls to infinite iterator/generator's next() function
                self.genlock = threading.Lock()
                self.queue = queue.Queue(maxsize=max_queue_size)
                self._stop_event = threading.Event()

            for _ in range(workers):
                if self._use_multiprocessing:
                    # Reset random seed else all children processes
                    # share the same seed
                    np.random.seed(self.seed)
                    thread = multiprocessing.Process(target=data_generator_task)
                    thread = multiprocessing.Process(target=self.__data_generator_task)
                    thread.daemon = True
                    if self.seed is not None:
                        self.seed += 1
                else:
                    thread = threading.Thread(target=data_generator_task)
                    thread = threading.Thread(target=self.__data_generator_task)
                self._threads.append(thread)
                thread.start()
        except:
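
The "reset random seed" comment above can be demonstrated in isolation. With the fork start method (the default on Linux), child processes inherit the parent's NumPy RNG state, so without distinct per-process seeds every worker draws the same numbers. A minimal, Unix-only sketch (it forces `fork` explicitly; not Keras code):

```python
import multiprocessing as mp

import numpy as np


def draw(seed):
    # Mirrors the enqueuer: optionally reseed, then draw from np.random.
    if seed is not None:
        np.random.seed(seed)
    print(mp.current_process().name, np.random.random(3))


if __name__ == '__main__':
    ctx = mp.get_context('fork')  # fork copies the parent's RNG state into each child

    # Without reseeding, both children print identical numbers.
    for _ in range(2):
        p = ctx.Process(target=draw, args=(None,))
        p.start()
        p.join()

    # With distinct seeds (as in start() above, which increments self.seed
    # per worker), the children diverge.
    for seed in (1, 2):
        p = ctx.Process(target=draw, args=(seed,))
        p.start()
        p.join()
```
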
@@ -691,11 +724,15 @@ def stop(self, timeout=None):
        self._stop_event.set()

        for thread in self._threads:
            if thread.is_alive():
                if self._use_multiprocessing:
            if self._use_multiprocessing:
                if thread.is_alive():
                    thread.terminate()
                else:
                    thread.join(timeout)
            else:
                # The thread.is_alive() test is subject to a race condition:
                # the thread could terminate right after the test and before the
                # join, rendering this test meaningless -> Call thread.join()
                # always, which is ok no matter what the status of the thread.
                thread.join(timeout)

        if self._manager:
            self._manager.shutdown()
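
The race-condition comment above can also be checked directly: joining a thread that has already finished returns immediately, so calling `join()` unconditionally is safe and the `is_alive()` guard adds nothing in the threading case. A stand-alone check (not part of the PR):

```python
import threading
import time

t = threading.Thread(target=lambda: None)
t.start()
time.sleep(0.2)        # give the no-op thread time to finish

print(t.is_alive())    # False: the thread has already terminated
t.join(timeout=5)      # returns immediately; joining a dead thread is fine
print('joined without blocking')
```
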
You only need one leading underscore to make a method private.
Per Python's class module documentation:
The one-underscore approach assumes developers abide by a coding convention and little else. The two-underscore approach uses name mangling to enforce some level of privacy (easy to break, sure, but at least it is trying). Is the "by convention" approach the only one you need me to follow?
I just want to make sure. Thanks, @fchollet!
Yes, all private methods in the Keras codebase use a single leading underscore. Thanks!
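
For reference, the runtime difference between the two conventions (a single leading underscore is purely a naming convention, while two leading underscores trigger name mangling to `_ClassName__attr`) can be seen with a toy class. The names below are made up for illustration:

```python
class Demo(object):
    def _single(self):
        return 'convention only'

    def __double(self):          # stored as _Demo__double via name mangling
        return 'mangled'

    def call_both(self):
        # Inside the class body, both spellings resolve normally.
        return self._single(), self.__double()


d = Demo()
print(d.call_both())             # ('convention only', 'mangled')
print(d._single())               # reachable; "private" only by convention
print(d._Demo__double())         # the mangled name is still reachable from outside
try:
    d.__double()                 # no mangling outside the class definition
except AttributeError as err:
    print(err)
```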