From 456ca1f400c37976ad09840a8a0f356164230f72 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Mon, 8 Apr 2019 00:21:36 -0700 Subject: [PATCH] Add exception handling support for waitall (#14397) * Relax constexpr restriction * Change imagenet_gen_qsym_mkldnn * Add exception handling support for waitall * Fix exception handling documentation * Revert constexpr change * Add comments * Fix test * Skip exception for op check names * Print exceptions thrown for CPP Package NDArray module * Reducing batch_size to make cpp-package example pass * Fix bug: #14426 * use ExceptionRef in threaded_engine code * add note for performance impact of waitall * Add check for GPU context * Use range for with const reference * Improve comments and error message for exception handling test * Change exception_ptr name in waitall * Fix bug --- cpp-package/example/resnet.cpp | 2 +- cpp-package/include/mxnet-cpp/ndarray.hpp | 6 +- docs/architecture/exception_handling.md | 3 - python/mxnet/ndarray/ndarray.py | 7 +- src/engine/threaded_engine.cc | 20 ++++ src/engine/threaded_engine.h | 36 ++++++- src/resource.cc | 14 +-- tests/python/unittest/test_exc_handling.py | 113 +++++++++++++++------ tests/python/unittest/test_operator.py | 14 ++- 9 files changed, 159 insertions(+), 56 deletions(-) diff --git a/cpp-package/example/resnet.cpp b/cpp-package/example/resnet.cpp index f59f60679544..8f8fd12e32ce 100644 --- a/cpp-package/example/resnet.cpp +++ b/cpp-package/example/resnet.cpp @@ -185,7 +185,7 @@ int main(int argc, char const *argv[]) { #if !MXNET_USE_CPU if (num_gpu > 0) { ctx = Context::gpu(); - batch_size = 50; + batch_size = 32; } #endif

diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index 966cf75c9122..b667542bffb5 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -233,12 +233,12 @@ inline NDArray NDArray::Reshape(const Shape &new_shape) const { return NDArray(handle); } 
inline void NDArray::WaitToRead() const { - CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0); + CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0) << MXGetLastError(); } inline void NDArray::WaitToWrite() { - CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0); + CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0) << MXGetLastError(); } -inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); } +inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0) << MXGetLastError(); } inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) { Operator("_random_normal")(mu, sigma).Invoke(*out); } diff --git a/docs/architecture/exception_handling.md b/docs/architecture/exception_handling.md index 6a9ab9ae0c4c..87481bcdb9bd 100644 --- a/docs/architecture/exception_handling.md +++ b/docs/architecture/exception_handling.md @@ -123,6 +123,3 @@ except mx.base.MXNetError as ex: d.asnumpy() ``` -### Limitation - -Rethrowing exceptions as part of `mx.nd.waitall` is not supported. So if your code executes a few operators and then calls `waitall` instead of `wait_to_read`/`asnumpy`, the exception will disappear. Please avoid waitalls in your code unless you are confident about your code not throwing exception in any scenario. diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index acb7b283aa76..87f2712d8a40 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -158,12 +158,9 @@ def waitall(): This function is used for benchmarking only. - .. warning:: + .. note:: - If your code has exceptions, `waitall` can cause silent failures. - For this reason you should avoid `waitall` in your code. - Use it only if you are confident that your code is error free. - Then make sure you call `wait_to_read` on all outputs after `waitall`. + If your mxnet code throws an exception, then waitall can cause a performance impact. 
""" check_call(_LIB.MXNDArrayWaitAll()) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index b5897a1ca9cd..986b6ad29909 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -415,6 +415,23 @@ void ThreadedEngine::WaitForAll() { finished_cv_.wait(lock, [this]() { return pending_.load() == 0 || kill_.load(); }); + std::exception_ptr exception_to_rethrow = nullptr; + if (!global_exception_refs_.empty()) { + // iterate through all exception refs + for (const auto& global_exception_ref : global_exception_refs_) { + // the first exception will be saved to be rethrown later + if (*global_exception_ref != nullptr && exception_to_rethrow == nullptr) { + exception_to_rethrow = *global_exception_ref; + } + // clear exceptions, WaitToRead following WaitForAll shouldn't throw + *global_exception_ref = nullptr; + } + // A waitall following a waitall shouldn't throw any exceptions + global_exception_refs_.clear(); + if (exception_to_rethrow != nullptr) { + std::rethrow_exception(exception_to_rethrow); + } + } } inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { @@ -428,6 +445,9 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { for (auto&& i : threaded_opr->mutable_vars) { if (threaded_opr->opr_exception && *threaded_opr->opr_exception) { i->var_exception = threaded_opr->opr_exception; + // add current operator exceptions to global exceptions if not already + // added + AddToGlobalExceptions(threaded_opr->opr_exception); } const bool debug_info = (engine_info_ && debug_wait_var_ == i); if (debug_info) { diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 640eac4de086..3d2119d63291 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -60,6 +60,9 @@ namespace engine { // Forward declarations struct ThreadedOpr; +/*! shared_ptr to exception_ptr, used for exception handling */ +typedef std::shared_ptr ExceptionRef; + /*! 
* \brief Operation block in the scheduler. * Each OprBlock corresponds to an operation pushed to the engine. @@ -177,8 +180,12 @@ class ThreadedVar final static std::atomic counter; ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } #endif // ENGINE_DEBUG - /*! \brief exception_ptr associated with the ThreadedVar */ - std::shared_ptr var_exception; + /*! + * \brief exception_ptr associated with the ThreadedOpr + * cannot modify state of exception object since dereferencing + * exception_ptr is undefined behavior. Using shared_ptr to hold + * exception_ptr and overcome this limitation */ + ExceptionRef var_exception; private: // TODO(hotpxl) change this to spinlock for faster runtime @@ -254,8 +261,12 @@ struct ThreadedOpr final : public Opr, } // define possible debug information DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr); - /*! \brief exception_ptr associated with the ThreadedOpr */ - std::shared_ptr opr_exception; + /*! + * \brief exception_ptr associated with the ThreadedOpr + * cannot modify state of exception object since dereferencing + * exception_ptr is undefined behavior. Using shared_ptr to hold + * exception_ptr and overcome this limitation */ + ExceptionRef opr_exception; }; // struct ThreadedOpr /*! @@ -432,6 +443,7 @@ class ThreadedEngine : public Engine { }; /*! thread local store for bulk */ typedef dmlc::ThreadLocalStore BulkStatusStore; + /*! * \brief check if thee is duplication in const_vars and mutable_vars. * \param const_vars the variables to read from. 
@@ -460,6 +472,7 @@ class ThreadedEngine : public Engine { for (auto&& i : threaded_opr->const_vars) { if (i->var_exception && *i->var_exception) { threaded_opr->opr_exception = i->var_exception; + AddToGlobalExceptions(threaded_opr->opr_exception); break; } } @@ -467,6 +480,7 @@ class ThreadedEngine : public Engine { for (auto&& i : threaded_opr->mutable_vars) { if (i->var_exception && *i->var_exception) { threaded_opr->opr_exception = i->var_exception; + AddToGlobalExceptions(threaded_opr->opr_exception); break; } } @@ -475,6 +489,18 @@ class ThreadedEngine : public Engine { static void OnCompleteStatic(Engine *engine, void *threaded_opr, const dmlc::Error* error); + /*! + * \brief find exception in global_exception_refs and add it if missing + * \param opr_exception the exception to be added to global_exception_refs + */ + inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) { + auto it = std::find(global_exception_refs_.begin(), + global_exception_refs_.end(), opr_exception); + if (it == global_exception_refs_.end()) { + global_exception_refs_.push_back(opr_exception); + } + return; + } /*! \brief append an operator to bulk */ inline void BulkAppend(SyncFn exec_fn, Context exec_ctx, std::vector const& const_vars, @@ -542,6 +568,8 @@ class ThreadedEngine : public Engine { */ std::mutex finished_m_; std::condition_variable finished_cv_; + /*! \brief global exception refs, which are rethrown when WaitForAll is called */ + std::vector global_exception_refs_; /*! 
* \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early diff --git a/src/resource.cc b/src/resource.cc index de24286ba535..cd6320d393b1 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -189,12 +189,14 @@ class ResourceManagerImpl : public ResourceManager { cpu_rand_->Seed(seed); cpu_parallel_rand_->Seed(seed); #if MXNET_USE_CUDA - gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() { - return new ResourceRandom(ctx, seed); - })->Seed(seed); - gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() { - return new ResourceParallelRandom(ctx, gpu_native_rand_copy_, seed); - })->Seed(seed); + if (ctx.dev_type == Context::kGPU) { + gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() { + return new ResourceRandom(ctx, seed); + })->Seed(seed); + gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() { + return new ResourceParallelRandom(ctx, gpu_native_rand_copy_, seed); + })->Seed(seed); + } #endif } diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py index e9e161d7f3b6..60799f821b8e 100644 --- a/tests/python/unittest/test_exc_handling.py +++ b/tests/python/unittest/test_exc_handling.py @@ -34,11 +34,11 @@ def imperative(exec_numpy=True): c.asnumpy() imperative(exec_numpy=False) - assert_raises(MXNetError, imperative, True) + assert_raises(MXNetError, imperative, exec_numpy=True) @with_seed() def test_exc_symbolic(): - def symbolic(exec_backward=True): + def symbolic(exec_backward=True, waitall=True): x = mx.sym.Variable('x') y = mx.sym.Variable('y') z = mx.sym.Variable('z') @@ -58,16 +58,25 @@ def symbolic(exec_backward=True): outputs = exec1.forward() if exec_backward: exec1.backward() - exec1.grad_arrays[0].asnumpy() + if waitall: + mx.nd.waitall() + else: + exec1.grad_arrays[0].asnumpy() else: - outputs[0].asnumpy() + if waitall: + mx.nd.waitall() + else: + outputs[0].asnumpy() - assert_raises(MXNetError, symbolic, False) - assert_raises(MXNetError, symbolic, True) + 
assert_raises(MXNetError, symbolic, exec_backward=False) + assert_raises(MXNetError, symbolic, exec_backward=True) + + assert_raises(MXNetError, symbolic, exec_backward=False, waitall=True) + assert_raises(MXNetError, symbolic, exec_backward=True, waitall=True) @with_seed() def test_exc_gluon(): - def gluon(exec_wait=True): + def gluon(exec_wait=True, waitall=False): model = nn.Sequential() model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False)) model.add(nn.Dropout(1)) @@ -77,46 +86,86 @@ def gluon(exec_wait=True): y = model(x) model.collect_params().initialize(ctx=[default_context()]) z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context())) - if exec_wait: + if waitall: + mx.nd.waitall() + elif exec_wait: z.wait_to_read() gluon(exec_wait=False) - assert_raises(MXNetError, gluon, True) + assert_raises(MXNetError, gluon, exec_wait=True) + + assert_raises(MXNetError, gluon, waitall=True) @with_seed() def test_exc_multiple_waits(): - caught = False - try: - a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) - a.wait_to_read() - except MXNetError: - caught = True - assert caught, "No exception thrown" - try: - b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) - b.wait_to_read() - except MXNetError: - caught = True - assert caught, "No exception thrown" + def multiple_waits(waitall=False): + # Test calling failed op followed by wait_to_read or waitall twice + # Intention is to test rethrow for multiple wait_to_reads and waitalls + # for vars with exceptions in same scope + caught = False + try: + a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) + if waitall: + mx.nd.waitall() + else: + a.wait_to_read() + except MXNetError: + caught = True + assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall" + try: + b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) + if waitall: + mx.nd.waitall() + else: + b.wait_to_read() + except MXNetError: + 
caught = True + assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall" + + multiple_waits(waitall=False) + multiple_waits(waitall=True) @with_seed() def test_exc_post_fail(): + def post_fail(waitall=False): + caught = False + try: + a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) + if waitall: + mx.nd.waitall() + else: + a.asnumpy() + except MXNetError: + caught = True + assert caught, "No exception thrown" + b.asnumpy() + post_fail(waitall=False) + post_fail(waitall=True) + +@with_seed() +def test_exc_mutable_var_fail(): + def mutable_var_check(waitall=False): + a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) + a = mx.nd.dot(a, a) + if waitall: + mx.nd.waitall() + else: + a.asnumpy() + assert_raises(MXNetError, mutable_var_check, waitall=False) + assert_raises(MXNetError, mutable_var_check, waitall=True) + +@with_seed() +def test_multiple_waitalls(): caught = False try: - a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) - a.asnumpy() + a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) + mx.nd.waitall() except MXNetError: caught = True assert caught, "No exception thrown" - b.asnumpy() + mx.nd.waitall() + -@with_seed() -def test_exc_mutable_var_fail(): - def mutable_var_check(): - a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) - a = mx.nd.dot(a, a) - a.asnumpy() - assert_raises(MXNetError, mutable_var_check) if __name__ == '__main__': import nose diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index f96a6aead58c..17618e414343 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -7164,7 +7164,12 @@ def get_output_names_callback(name, arr): op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null') op_exe.set_monitor_callback(get_output_names_callback, monitor_all=False) - op_exe.forward() + try: + op_exe.forward() + mx.nd.waitall() + 
except mx.base.MXNetError: + # skip errors since test is to check output names + pass for output_name, expected_name in zip(output_names, expected_names): assert output_name == expected_name @@ -7210,7 +7215,12 @@ def get_output_names_callback(name, arr): op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null') op_exe.set_monitor_callback(get_output_names_callback, monitor_all=True) - op_exe.forward() + try: + op_exe.forward() + mx.nd.waitall() + except mx.base.MXNetError: + # skip errors since test is to check all names + pass for output_name, expected_name in zip(output_names, expected_names): assert output_name == expected_name