-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Better Exception Handling for Operators #9681
Changes from 12 commits
a366ef5
0d12ebe
a9728d6
fb5e4ef
8d59764
70e914f
a9a51e8
eb1166b
dcd5fab
c80d286
bcbe19a
51b0dfb
fe1fa4e
3bcc372
87da504
fa0d83d
bbf6702
320de0e
ec5d68b
04dd9de
f7349cb
5ae7743
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -147,7 +147,7 @@ class MXNET_API Engine { | |
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop = FnProperty::kNormal, | ||
const char* opr_name = nullptr) = 0; | ||
const char* opr_name = nullptr, bool wait = false) = 0; | ||
/*! | ||
* \brief Delete the given operator. | ||
* \param op The operator to delete. | ||
|
@@ -182,7 +182,7 @@ class MXNET_API Engine { | |
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop = FnProperty::kNormal, | ||
int priority = 0, | ||
const char* opr_name = nullptr) = 0; | ||
const char* opr_name = nullptr, bool wait = false) = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's wait? Why do you need it? Please document arguments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wait is used to indicate to the ExecuteOprBlock whether it is a waitforvar operation. it should not block the execution of the operator for WaitForVar. |
||
/*! | ||
* \brief Schedule the deletion of a variable. | ||
* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,7 +73,7 @@ class NaiveEngine final : public Engine { | |
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop = FnProperty::kNormal, | ||
const char* opr_name = nullptr) override { | ||
const char* opr_name = nullptr, bool wait = false) override { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Put in the next line. |
||
NaiveOpr *opr = new NaiveOpr(); | ||
opr->fn = fn; | ||
opr->const_vars = const_vars; | ||
|
@@ -125,7 +125,7 @@ class NaiveEngine final : public Engine { | |
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop = FnProperty::kNormal, | ||
int priority = 0, | ||
const char* opr_name = nullptr) override { | ||
const char* opr_name = nullptr, bool wait = false) override { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Next line. |
||
CallbackOnComplete callback = CreateCallback( | ||
NaiveEngine::OnComplete, nullptr); | ||
this->req_completed_ = false; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -206,13 +206,14 @@ ThreadedOpr* ThreadedEngine::NewOperator( | |
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop, | ||
const char* opr_name) { | ||
const char* opr_name, bool wait) { | ||
auto ret = ThreadedOpr::New(); | ||
ret->opr_name = opr_name; | ||
ret->fn = std::move(fn); | ||
ret->prop = prop; | ||
ret->const_vars.resize(const_vars.size()); | ||
ret->mutable_vars.resize(mutable_vars.size()); | ||
ret->wait = wait; | ||
std::transform(const_vars.begin(), const_vars.end(), | ||
ret->const_vars.begin(), ThreadedVar::CastFromBase); | ||
std::transform(mutable_vars.begin(), mutable_vars.end(), | ||
|
@@ -305,9 +306,9 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, | |
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop, | ||
int priority, | ||
const char* opr_name) { | ||
const char* opr_name, bool wait) { | ||
BulkFlush(); | ||
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name); | ||
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait); | ||
opr->temporary = true; | ||
#if MXNET_USE_PROFILER | ||
Profiler *profiler = Profiler::Get(); | ||
|
@@ -356,7 +357,11 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn, | |
void ThreadedEngine::WaitForVar(VarHandle var) { | ||
BulkFlush(); | ||
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); | ||
if (threaded_var->ready_to_read()) return; | ||
if (threaded_var->ready_to_read()) { | ||
if (threaded_var->ex_ptr) { | ||
std::rethrow_exception(threaded_var->ex_ptr); | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why are you not returning? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch! Fixed. |
||
if (engine_info_) { | ||
LOG(INFO) << "Wait for " << threaded_var; | ||
debug_wait_var_ = threaded_var; | ||
|
@@ -376,13 +381,17 @@ void ThreadedEngine::WaitForVar(VarHandle var) { | |
} | ||
on_complete(); | ||
}, Context::CPU(), {var}, {}, FnProperty::kNormal, 0, | ||
PROFILER_MESSAGE("WaitForVar")); | ||
PROFILER_MESSAGE("WaitForVar"), true); | ||
{ | ||
std::unique_lock<std::mutex> lock{finished_m_}; | ||
finished_cv_.wait(lock, [this, &done]() { | ||
return done.load() || kill_.load(); | ||
}); | ||
} | ||
|
||
if (threaded_var->ex_ptr) { | ||
std::rethrow_exception(threaded_var->ex_ptr); | ||
} | ||
} | ||
|
||
void ThreadedEngine::WaitForAll() { | ||
|
@@ -391,6 +400,11 @@ void ThreadedEngine::WaitForAll() { | |
finished_cv_.wait(lock, [this]() { | ||
return pending_.load() == 0 || kill_.load(); | ||
}); | ||
if (global_ex_ptr) { | ||
std::exception_ptr ex_ptr = global_ex_ptr; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use std::rethrow_exception(std::move(global_ex_ptr)) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rolling this back, Since the state of global_ex_ptr is not guaranteed to be nullptr after the move and it depends on the implementation. This probably explains why it started failing on windows, after the change. Please let me know if you have any concerns. |
||
global_ex_ptr = nullptr; | ||
std::rethrow_exception(ex_ptr); | ||
} | ||
} | ||
|
||
inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { | ||
|
@@ -403,7 +417,11 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { | |
} | ||
// Mark complete for write variables. | ||
for (auto&& i : threaded_opr->mutable_vars) { | ||
const bool debug_info = (engine_info_ && debug_wait_var_ == i); | ||
if (threaded_opr->ex_ptr) { | ||
i->ex_ptr = threaded_opr->ex_ptr; | ||
if (!global_ex_ptr) global_ex_ptr = i->ex_ptr; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suppose and operator has three outputs x, y and z and it raises an exception. and if I do z += 1 and it succeeds, z.asnumpy() would still raise the error There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. z += 1 won't execute since z already has an exception associated with it. z.asnumpy() will still raise the error. |
||
bool debug_info = (engine_info_ && debug_wait_var_ == i); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why remove const? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was not intentional. Added it back. |
||
if (debug_info) { | ||
LOG(INFO) << "Complete write dep for " << i; | ||
} | ||
|
@@ -443,6 +461,14 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { | |
} | ||
} | ||
|
||
inline void ThreadedEngine::OnStart(ThreadedOpr* threaded_opr) { | ||
for (auto&& i : threaded_opr->const_vars) { | ||
if (i->ex_ptr) { | ||
threaded_opr->ex_ptr = i->ex_ptr; | ||
} | ||
} | ||
} | ||
|
||
void ThreadedEngine::OnCompleteStatic( | ||
Engine *engine, void *opr_block_) { | ||
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_); | ||
|
@@ -457,5 +483,12 @@ void ThreadedEngine::OnCompleteStatic( | |
OprBlock::Delete(opr_block); | ||
} | ||
|
||
void ThreadedEngine::OnStartStatic( | ||
Engine *engine, void *opr_block_) { | ||
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_); | ||
ThreadedOpr *threaded_opr = opr_block->opr; | ||
static_cast<ThreadedEngine*>(engine)->OnStart(threaded_opr); | ||
} | ||
|
||
} // namespace engine | ||
} // namespace mxnet |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -175,6 +175,8 @@ class ThreadedVar final | |
static std::atomic<std::size_t> counter; | ||
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } | ||
#endif // ENGINE_DEBUG | ||
/*! \brief exception_ptr associated with the ThreadedVar */ | ||
std::exception_ptr ex_ptr; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ex_ptr is a bad variable name There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i changed ex_ptr for var to var_exception, ex_ptr for opr to opr_exception and the global exception_ptr global_exception_ . I welcome if you have any other suggestions for naming them. |
||
|
||
private: | ||
// TODO(hotpxl) change this to spinlock for faster runtime | ||
|
@@ -236,6 +238,10 @@ struct ThreadedOpr final : public Opr, | |
* that can be deleted right after the operation completed. | ||
*/ | ||
bool temporary{false}; | ||
/*! | ||
* \brief Whether this is a wait operation like WaitForVar | ||
*/ | ||
bool wait{false}; | ||
/*! | ||
* \brief Cast a Opr pointer to ThreadedOpr pointer | ||
* \param ptr pointer from base. | ||
|
@@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr, | |
} | ||
// define possible debug information | ||
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr); | ||
/*! \brief exception_ptr associated with the ThreadedOpr */ | ||
std::exception_ptr ex_ptr; | ||
}; // struct ThreadedOpr | ||
|
||
/*! | ||
|
@@ -265,15 +273,15 @@ class ThreadedEngine : public Engine { | |
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop = FnProperty::kNormal, | ||
const char* opr_name = nullptr) override; | ||
const char* opr_name = nullptr, bool wait = false) override; | ||
void DeleteOperator(OprHandle op) override; | ||
void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override; | ||
void PushAsync(AsyncFn exec_fun, Context exec_ctx, | ||
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop = FnProperty::kNormal, | ||
int priority = 0, | ||
const char* opr_name = nullptr) override; | ||
const char* opr_name = nullptr, bool wait = false) override; | ||
void PushSync(SyncFn exec_fn, Context exec_ctx, | ||
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
|
@@ -338,33 +346,46 @@ class ThreadedEngine : public Engine { | |
#endif | ||
CallbackOnComplete callback = this->CreateCallback( | ||
ThreadedEngine::OnCompleteStatic, opr_block); | ||
CallbackOnComplete on_start_callback = this->CreateCallback( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the point of creating a call back here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The callback is not strictly necessary here, since it is called only once, but I included it to keep the ExecuteOprBlock easier to read and separate out the OnStart logic. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is unnecessary overhead. Call OnStart directly if possible There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Called OnStart directly. |
||
ThreadedEngine::OnStartStatic, opr_block); | ||
bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); | ||
if (debug_info) { | ||
LOG(INFO) << "ExecuteOprBlock " << opr_block | ||
<< "shutdown_phase=" << shutdown_phase_; | ||
} | ||
if (!shutdown_phase_) { | ||
try { | ||
on_start_callback(); | ||
if (debug_info) { | ||
LOG(INFO) << "ExecuteOprFn "; | ||
} | ||
threaded_opr->fn(run_ctx, callback); | ||
try { | ||
if (!threaded_opr->ex_ptr || threaded_opr->wait) { | ||
threaded_opr->fn(run_ctx, callback); | ||
} else { | ||
callback(); | ||
} | ||
} catch (dmlc::Error& e) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why only catch dmlc error? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason I catch only dmlc::Error is that The guards in the c_api API_BEGIN and API_END/API_END_HANDLE_ERROR only catch dmlc::Error currently and propagate to frontend. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have used dmlc::Error in the inner block and used std::exception to catch other stdlib exceptions thrown in the outer block. Currently, it will catch only dmlc::Error and for other exceptions(std::exception) the process will be terminated. I will open another PR to handle std::exception and change the c_api guards and frontend code. |
||
threaded_opr->ex_ptr = std::current_exception(); | ||
callback(); | ||
} | ||
if (debug_info) { | ||
LOG(INFO) << "Fin ExecuteOprFn "; | ||
} | ||
} catch(dmlc::Error &e) { | ||
} catch (dmlc::Error& e) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what would this catch now? There is already a try block inside There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @piiswrong my intention for the outer block was to catch all other exceptions which are not caught in the inner block. I should change dmlc::Error to std::exception to catch all standard exception. But you make a good point about propagating all the exceptions and not just dmlc::Error to the frontend. We can take one of the two approaches here: 1. catch dmlc::Error and terminate the process for everything else. 2. catch both dmlc::Error and standard exceptions and propagate to frontend. This will require code changes to the guards , c_api_error and potentially frontend code. |
||
std::string what = e.what(); | ||
if (what.find("driver shutting down") == std::string::npos && | ||
!shutdown_phase_) { | ||
LOG(FATAL) << e.what() << "\n" << | ||
"A fatal error occurred in asynchronous engine operation. " | ||
"If you do not know what caused this error, " | ||
"you can try set environment variable MXNET_ENGINE_TYPE " | ||
"to NaiveEngine and run with debugger (i.e. gdb). " | ||
"This will force all operations to be synchronous and " | ||
"backtrace will give you the series of calls that lead " | ||
"to this error. Remember to set MXNET_ENGINE_TYPE back to " | ||
"empty after debugging."; | ||
LOG(FATAL) | ||
<< e.what() << "\n" | ||
<< "A fatal error occurred in asynchronous engine operation. " | ||
"If you do not know what caused this error, " | ||
"you can try set environment variable MXNET_ENGINE_TYPE " | ||
"to NaiveEngine and run with debugger (i.e. gdb). " | ||
"This will force all operations to be synchronous and " | ||
"backtrace will give you the series of calls that lead " | ||
"to this error. Remember to set MXNET_ENGINE_TYPE back to " | ||
"empty after debugging."; | ||
} | ||
} | ||
} else { | ||
|
@@ -414,8 +435,18 @@ class ThreadedEngine : public Engine { | |
* On operation completion, this will trigger subsequent operations. | ||
*/ | ||
inline void OnComplete(ThreadedOpr* threaded_opr); | ||
/*! | ||
* \brief Callback before operation start. | ||
* | ||
* Will mark the operator as a failure and associate exception_ptr | ||
* if any of the read dependencies have exception associated. | ||
*/ | ||
inline void OnStart(ThreadedOpr* threaded_opr); | ||
// callback to the threaded engine | ||
static void OnCompleteStatic(Engine *engine, void *threaded_opr); | ||
// callback to mark exceptions if required before | ||
// operator execution | ||
static void OnStartStatic(Engine* engine, void *threaded_opr); | ||
/*! \brief append an operator to bulk */ | ||
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx, | ||
std::vector<VarHandle> const& const_vars, | ||
|
@@ -476,6 +507,9 @@ class ThreadedEngine : public Engine { | |
*/ | ||
std::mutex finished_m_; | ||
std::condition_variable finished_cv_; | ||
/*! \brief exception_ptr associated with the engine, | ||
* which is used to throw exception in waitall */ | ||
std::exception_ptr global_ex_ptr; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ex_ptr is a bad name. class members should end with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed! |
||
|
||
/*! | ||
* \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) { | |
void* ptr; | ||
#if _MSC_VER | ||
ptr = _aligned_malloc(size, alignment_); | ||
if (ptr == NULL) throw std::bad_alloc(); | ||
if (ptr == NULL) LOG(FATAL) << "Malloc failure"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Failed to allocate CPU memory |
||
#else | ||
int ret = posix_memalign(&ptr, alignment_, size); | ||
if (ret != 0) throw std::bad_alloc(); | ||
if (ret != 0) LOG(FATAL) << "Malloc failure"; | ||
#endif | ||
return ptr; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) { | |
#endif // MXNET_USE_NCCL | ||
cudaError_t e = cudaMalloc(&ret, size); | ||
if (e != cudaSuccess && e != cudaErrorCudartUnloading) | ||
throw std::bad_alloc(); | ||
LOG(FATAL) << cudaGetLastError(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does cudaGetLastError return a string or an error code? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch! I have changed it to call cudaGetErrorString |
||
#else // MXNET_USE_CUDA | ||
LOG(FATAL) << "Please compile with CUDA enabled"; | ||
#endif // MXNET_USE_CUDA | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
\param
forwait
in doc.bool wait...
in the next line to keep the coding style consistent with the existing context. Same all the following changes.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! Fixed here and other places.