Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Better Exception Handling for Operators #9681

Merged
merged 22 commits into from
Feb 13, 2018
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class MXNET_API Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr, bool wait = false) = 0;
/*!
* \brief Delete the given operator.
* \param op The operator to delete.
Expand Down Expand Up @@ -182,7 +182,7 @@ class MXNET_API Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr, bool wait = false) = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Add \param for wait in doc.
  2. Put bool wait... in the next line to keep the coding style consistent with the existing context. Same all the following changes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed here and other places.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's wait? Why do you need it? Please document arguments

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait is used to indicate to the ExecuteOprBlock whether it is a waitforvar operation. it should not block the execution of the operator for WaitForVar.

/*!
* \brief Schedule the deletion of a variable.
*
Expand Down
4 changes: 2 additions & 2 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr, bool wait = false) override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put in the next line.

NaiveOpr *opr = new NaiveOpr();
opr->fn = fn;
opr->const_vars = const_vars;
Expand Down Expand Up @@ -125,7 +125,7 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr, bool wait = false) override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Next line.

CallbackOnComplete callback = CreateCallback(
NaiveEngine::OnComplete, nullptr);
this->req_completed_ = false;
Expand Down
45 changes: 39 additions & 6 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,14 @@ ThreadedOpr* ThreadedEngine::NewOperator(
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
const char* opr_name) {
const char* opr_name, bool wait) {
auto ret = ThreadedOpr::New();
ret->opr_name = opr_name;
ret->fn = std::move(fn);
ret->prop = prop;
ret->const_vars.resize(const_vars.size());
ret->mutable_vars.resize(mutable_vars.size());
ret->wait = wait;
std::transform(const_vars.begin(), const_vars.end(),
ret->const_vars.begin(), ThreadedVar::CastFromBase);
std::transform(mutable_vars.begin(), mutable_vars.end(),
Expand Down Expand Up @@ -305,9 +306,9 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
int priority,
const char* opr_name) {
const char* opr_name, bool wait) {
BulkFlush();
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
opr->temporary = true;
#if MXNET_USE_PROFILER
Profiler *profiler = Profiler::Get();
Expand Down Expand Up @@ -356,7 +357,11 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
void ThreadedEngine::WaitForVar(VarHandle var) {
BulkFlush();
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
if (threaded_var->ready_to_read()) return;
if (threaded_var->ready_to_read()) {
if (threaded_var->ex_ptr) {
std::rethrow_exception(threaded_var->ex_ptr);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are you not returning?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed.

if (engine_info_) {
LOG(INFO) << "Wait for " << threaded_var;
debug_wait_var_ = threaded_var;
Expand All @@ -376,13 +381,17 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
}
on_complete();
}, Context::CPU(), {var}, {}, FnProperty::kNormal, 0,
PROFILER_MESSAGE("WaitForVar"));
PROFILER_MESSAGE("WaitForVar"), true);
{
std::unique_lock<std::mutex> lock{finished_m_};
finished_cv_.wait(lock, [this, &done]() {
return done.load() || kill_.load();
});
}

if (threaded_var->ex_ptr) {
std::rethrow_exception(threaded_var->ex_ptr);
}
}

void ThreadedEngine::WaitForAll() {
Expand All @@ -391,6 +400,11 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() {
return pending_.load() == 0 || kill_.load();
});
if (global_ex_ptr) {
std::exception_ptr ex_ptr = global_ex_ptr;
Copy link
Contributor

@piiswrong piiswrong Feb 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use std::rethrow_exception(std::move(global_ex_ptr))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rolling this back, Since the state of global_ex_ptr is not guaranteed to be nullptr after the move and it depends on the implementation. This probably explains why it started failing on windows, after the change. Please let me know if you have any concerns.

global_ex_ptr = nullptr;
std::rethrow_exception(ex_ptr);
}
}

inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
Expand All @@ -403,7 +417,11 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}
// Mark complete for write variables.
for (auto&& i : threaded_opr->mutable_vars) {
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
if (threaded_opr->ex_ptr) {
i->ex_ptr = threaded_opr->ex_ptr;
if (!global_ex_ptr) global_ex_ptr = i->ex_ptr;
}
Copy link
Contributor

@piiswrong piiswrong Feb 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suppose and operator has three outputs x, y and z and it raises an exception.
then x.asnumpy() would raise an error.
Then y.asnumpy() would raise the same error again.

and if I do z += 1 and it succeeds, z.asnumpy() would still raise the error

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

z += 1 won't execute since z already has an exception associated with it. z.asnumpy() will still raise the error.

bool debug_info = (engine_info_ && debug_wait_var_ == i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove const?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was not intentional. Added it back.

if (debug_info) {
LOG(INFO) << "Complete write dep for " << i;
}
Expand Down Expand Up @@ -443,6 +461,14 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}
}

inline void ThreadedEngine::OnStart(ThreadedOpr* threaded_opr) {
for (auto&& i : threaded_opr->const_vars) {
if (i->ex_ptr) {
threaded_opr->ex_ptr = i->ex_ptr;
}
}
}

void ThreadedEngine::OnCompleteStatic(
Engine *engine, void *opr_block_) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
Expand All @@ -457,5 +483,12 @@ void ThreadedEngine::OnCompleteStatic(
OprBlock::Delete(opr_block);
}

void ThreadedEngine::OnStartStatic(
Engine *engine, void *opr_block_) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
ThreadedOpr *threaded_opr = opr_block->opr;
static_cast<ThreadedEngine*>(engine)->OnStart(threaded_opr);
}

} // namespace engine
} // namespace mxnet
60 changes: 47 additions & 13 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::exception_ptr ex_ptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ex_ptr is a bad variable name

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i changed ex_ptr for var to var_exception, ex_ptr for opr to opr_exception and the global exception_ptr global_exception_ . I welcome if you have any other suggestions for naming them.


private:
// TODO(hotpxl) change this to spinlock for faster runtime
Expand Down Expand Up @@ -236,6 +238,10 @@ struct ThreadedOpr final : public Opr,
* that can be deleted right after the operation completed.
*/
bool temporary{false};
/*!
* \brief Whether this is a wait operation like WaitForVar
*/
bool wait{false};
/*!
* \brief Cast a Opr pointer to ThreadedOpr pointer
* \param ptr pointer from base.
Expand All @@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr,
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
/*! \brief exception_ptr associated with the ThreadedOpr */
std::exception_ptr ex_ptr;
}; // struct ThreadedOpr

/*!
Expand All @@ -265,15 +273,15 @@ class ThreadedEngine : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override;
const char* opr_name = nullptr, bool wait = false) override;
void DeleteOperator(OprHandle op) override;
void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override;
const char* opr_name = nullptr, bool wait = false) override;
void PushSync(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
Expand Down Expand Up @@ -338,33 +346,46 @@ class ThreadedEngine : public Engine {
#endif
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnCompleteStatic, opr_block);
CallbackOnComplete on_start_callback = this->CreateCallback(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the point of creating a call back here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The callback is not strictly necessary here, since it is called only once, but I included it to keep the ExecuteOprBlock easier to read and separate out the OnStart logic.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is unnecessary overhead. Call OnStart directly if possible

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Called OnStart directly.

ThreadedEngine::OnStartStatic, opr_block);
bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
if (debug_info) {
LOG(INFO) << "ExecuteOprBlock " << opr_block
<< "shutdown_phase=" << shutdown_phase_;
}
if (!shutdown_phase_) {
try {
on_start_callback();
if (debug_info) {
LOG(INFO) << "ExecuteOprFn ";
}
threaded_opr->fn(run_ctx, callback);
try {
if (!threaded_opr->ex_ptr || threaded_opr->wait) {
threaded_opr->fn(run_ctx, callback);
} else {
callback();
}
} catch (dmlc::Error& e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why only catch dmlc error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I catch only dmlc::Error is that The guards in the c_api API_BEGIN and API_END/API_END_HANDLE_ERROR only catch dmlc::Error currently and propagate to frontend.

Copy link
Member Author

@anirudh2290 anirudh2290 Feb 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have used dmlc::Error in the inner block and used std::exception to catch other stdlib exceptions thrown in the outer block. Currently, it will catch only dmlc::Error and for other exceptions(std::exception) the process will be terminated. I will open another PR to handle std::exception and change the c_api guards and frontend code.

threaded_opr->ex_ptr = std::current_exception();
callback();
}
if (debug_info) {
LOG(INFO) << "Fin ExecuteOprFn ";
}
} catch(dmlc::Error &e) {
} catch (dmlc::Error& e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would this catch now? There is already a try block inside

Copy link
Member Author

@anirudh2290 anirudh2290 Feb 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@piiswrong my intention for the outer block was to catch all other exceptions which are not caught in the inner block. I should change dmlc::Error to std::exception to catch all standard exception. But you make a good point about propagating all the exceptions and not just dmlc::Error to the frontend. We can take one of the two approaches here: 1. catch dmlc::Error and terminate the process for everything else. 2. catch both dmlc::Error and standard exceptions and propagate to frontend. This will require code changes to the guards , c_api_error and potentially frontend code.

std::string what = e.what();
if (what.find("driver shutting down") == std::string::npos &&
!shutdown_phase_) {
LOG(FATAL) << e.what() << "\n" <<
"A fatal error occurred in asynchronous engine operation. "
"If you do not know what caused this error, "
"you can try set environment variable MXNET_ENGINE_TYPE "
"to NaiveEngine and run with debugger (i.e. gdb). "
"This will force all operations to be synchronous and "
"backtrace will give you the series of calls that lead "
"to this error. Remember to set MXNET_ENGINE_TYPE back to "
"empty after debugging.";
LOG(FATAL)
<< e.what() << "\n"
<< "A fatal error occurred in asynchronous engine operation. "
"If you do not know what caused this error, "
"you can try set environment variable MXNET_ENGINE_TYPE "
"to NaiveEngine and run with debugger (i.e. gdb). "
"This will force all operations to be synchronous and "
"backtrace will give you the series of calls that lead "
"to this error. Remember to set MXNET_ENGINE_TYPE back to "
"empty after debugging.";
}
}
} else {
Expand Down Expand Up @@ -414,8 +435,18 @@ class ThreadedEngine : public Engine {
* On operation completion, this will trigger subsequent operations.
*/
inline void OnComplete(ThreadedOpr* threaded_opr);
/*!
* \brief Callback before operation start.
*
* Will mark the operator as a failure and associate exception_ptr
* if any of the read dependencies have exception associated.
*/
inline void OnStart(ThreadedOpr* threaded_opr);
// callback to the threaded engine
static void OnCompleteStatic(Engine *engine, void *threaded_opr);
// callback to mark exceptions if required before
// operator execution
static void OnStartStatic(Engine* engine, void *threaded_opr);
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
Expand Down Expand Up @@ -476,6 +507,9 @@ class ThreadedEngine : public Engine {
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
/*! \brief exception_ptr associated with the engine,
* which is used to throw exception in waitall */
std::exception_ptr global_ex_ptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ex_ptr is a bad name. class members should end with _.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed!


/*!
* \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
Expand Down
4 changes: 2 additions & 2 deletions src/storage/cpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) {
void* ptr;
#if _MSC_VER
ptr = _aligned_malloc(size, alignment_);
if (ptr == NULL) throw std::bad_alloc();
if (ptr == NULL) LOG(FATAL) << "Malloc failure";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Failed to allocate CPU memory

#else
int ret = posix_memalign(&ptr, alignment_, size);
if (ret != 0) throw std::bad_alloc();
if (ret != 0) LOG(FATAL) << "Malloc failure";
#endif
return ptr;
}
Expand Down
2 changes: 1 addition & 1 deletion src/storage/gpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) {
#endif // MXNET_USE_NCCL
cudaError_t e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading)
throw std::bad_alloc();
LOG(FATAL) << cudaGetLastError();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does cudaGetLastError return a string or an error code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I have changed it to call cudaGetErrorString

#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
Expand Down
1 change: 1 addition & 0 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from test_random import *
from test_gluon import *
from test_loss import *
from test_exc_handling import *
#from test_rnn import *
from test_gluon_rnn import *
from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
Expand Down
Loading