-
Notifications
You must be signed in to change notification settings - Fork 152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Attach traceback to Exception & Test dispatching process #1036
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import sys | ||
import traceback | ||
|
||
|
||
class KeyErrorMessage(str):
    r"""A ``str`` subclass whose ``repr`` is the raw string itself.

    Normally ``repr`` quotes and escapes a string, which would mangle the
    multi-line traceback text embedded in a ``KeyError`` argument; returning
    the text verbatim keeps the printed exception readable.
    """

    def __repr__(self) -> str:
        # Return the unquoted, unescaped text as-is.
        return self
|
||
|
||
class ExceptionWrapper:
    r"""
    Wraps an exception together with its formatted traceback so it can be
    sent across thread/process boundaries and re-raised on the other side.
    """

    def __init__(self, exc_info=None, where: str = "in background"):
        # Default to the exception currently being handled in this frame.
        info = sys.exc_info() if exc_info is None else exc_info
        self.exc_type = info[0]
        # Pre-render the full traceback as text: traceback objects
        # themselves are not picklable across processes.
        self.exc_msg = "".join(traceback.format_exception(*info))
        self.where = where

    def reraise(self):
        r"""
        Reraises the wrapped exception in the current thread/process
        """
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = "Caught {} {}.\nOriginal {}".format(self.exc_type.__name__, self.where, self.exc_msg)
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key),
            # which makes stack traces unreadable. This won't be changed in
            # Python (https://bugs.python.org/issue2651), so work around it
            # with a str subclass whose repr is itself.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exception types take a non-str first argument but expose
            # an explicit `message` field — pass the text via that keyword.
            raise self.exc_type(message=msg)
        try:
            rebuilt = self.exc_type(msg)
        except TypeError:
            # The constructor needs multiple arguments we can't guess;
            # fall back to a RuntimeError carrying the full text.
            raise RuntimeError(msg) from None
        raise rebuilt
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,7 @@ def reset(self) -> None: | |
self.cnt = 0 | ||
|
||
|
||
def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call_on_process_init=None): | ||
def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, name, call_on_process_init=None): | ||
r""" | ||
Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes | ||
with the protocol server in a non-blocking manner. | ||
|
@@ -85,6 +85,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call | |
source_datapipe, | ||
req_queue, | ||
res_queue, | ||
name, | ||
blocking_request_get=False, | ||
reset_iterator_counter=reset_iterator_counter, | ||
) | ||
|
@@ -99,7 +100,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call | |
time.sleep(0) | ||
|
||
|
||
def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_init=None): | ||
def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, name, call_on_process_init=None): | ||
r""" | ||
Initialize with the given init function, set the appropriate pipe and protocol server type, and | ||
create a loop with the protocol server. | ||
|
@@ -112,14 +113,19 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_ | |
|
||
torch.set_num_threads(1) | ||
|
||
loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, blocking_request_get=True) | ||
loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, name, blocking_request_get=True) | ||
|
||
for _ in loop: | ||
pass | ||
|
||
|
||
def _create_datapipe_queue_loop( | ||
source_datapipe, req_queue, res_queue, blocking_request_get=True, reset_iterator_counter=None | ||
source_datapipe, | ||
req_queue, | ||
res_queue, | ||
name, | ||
blocking_request_get=True, | ||
reset_iterator_counter=None, | ||
): | ||
if isinstance(source_datapipe, IterDataPipe): | ||
pipe_type = communication.iter | ||
|
@@ -133,25 +139,26 @@ def _create_datapipe_queue_loop( | |
return pipe_type.DataPipeBehindQueues( | ||
source_datapipe, | ||
protocol_type(req_queue, res_queue), | ||
name=name, | ||
blocking_request_get=blocking_request_get, | ||
reset_iterator_counter=reset_iterator_counter, | ||
) | ||
|
||
|
||
def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, call_on_process_init=None): | ||
def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, name, call_on_process_init=None): | ||
r""" | ||
Given a DataPipe, creates a new process with ``DataPipeToQueuesLoop`` as target, | ||
and returns ``(process, req_queue, res_queue)``. | ||
""" | ||
req_queue = multiprocessing_ctx.Queue() | ||
res_queue = multiprocessing_ctx.Queue() | ||
process = multiprocessing_ctx.Process( | ||
target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, call_on_process_init) | ||
target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, name, call_on_process_init) | ||
) | ||
return process, req_queue, res_queue | ||
|
||
|
||
def CreateThreadForDataPipeline(datapipe): | ||
def CreateThreadForDataPipeline(datapipe, name): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Forgot to ask before: if this function is used anywhere currently? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think there is any usage for sure. We can remove it in the future since we might rely on |
||
r""" | ||
Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with ``DataPipeToQueuesLoop`` as target, | ||
and returns ``(process, req_queue, res_queue, new_copied_datapipe)``. | ||
|
@@ -171,11 +178,13 @@ def CreateThreadForDataPipeline(datapipe): | |
else: | ||
raise Exception("Unable to pickle DataPipe to make thread local copy (consider installing `dill`)", pe) | ||
|
||
process = threading.Thread(target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue), daemon=True) | ||
process = threading.Thread( | ||
target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue, name), daemon=True | ||
) | ||
return process, req_queue, res_queue, new_datapipe | ||
|
||
|
||
def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes): | ||
def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, name): | ||
ejguan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
r""" | ||
Given a DataPipe, creates a new process with ``MultipleDataPipesToQueuesLoop`` as target, | ||
and returns ``(process, [req_queue_0, ...], [res_queue_0, ...])``. | ||
|
@@ -187,6 +196,6 @@ def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes): | |
res_queues.append(multiprocessing_ctx.Queue()) | ||
|
||
process = multiprocessing_ctx.Process( | ||
target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues) | ||
target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues, name) | ||
) | ||
return process, req_queues, res_queues |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Maybe check if DL2 is shut down properly? Is that possible to test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, DL2 is not properly shut down for now, because we only raise the error in the main process. We haven't yet handled shutting down the other worker processes when one process hits an error. This is probably also the reason it takes a longer time to shut down.
This is tracked in #969