Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Better Exception Handling for Operators #9681

Merged
merged 22 commits into from
Feb 13, 2018

Conversation

anirudh2290
Copy link
Member

@anirudh2290 anirudh2290 commented Feb 2, 2018

Description

Please see: https://cwiki.apache.org/confluence/display/MXNET/Improved+exception+handling+in+MXNet
Implements Exception Handling for Operators.

Fixes #7335 and related issues

Functional Testing

Performance Testing

I did a small performance testing task with resnet50 model on cifar10 dataset to make sure that there is no performance degradation because of the additional overhead of OnStart callback in each ExecuteOprBlock. I don't see any change.

Model: resnet50
Dataset: cifar10
Tested on ec2: p2.8xlarge

Mode Measurement Before Exception Handling Change After Exception Handling Change
Symbolic Speed 1064 samples/sec 1073 samples/sec
Imperative Speed 187 samples/sec 184 samples/sec

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Exception handling for Operators, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@cjolivier01 @mli @piiswrong @madjam @asmushetzel @eric-haibin-lin @reminisce @rahul003 @KellenSunderland @eftiquar

@KellenSunderland
Copy link
Contributor

This is going to be a big improvement for new users experimenting with the library. Thanks for the great work @anirudh2290.

#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: spacing

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed.

@@ -182,7 +182,7 @@ class MXNET_API Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr, bool wait = false) = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's wait? Why do you need it? Please document arguments

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait is used to indicate to the ExecuteOprBlock whether it is a waitforvar operation. it should not block the execution of the operator for WaitForVar.

if (threaded_var->ex_ptr) {
std::rethrow_exception(threaded_var->ex_ptr);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are you not returning?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed.

@@ -391,6 +400,11 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() {
return pending_.load() == 0 || kill_.load();
});
if (global_ex_ptr) {
std::exception_ptr ex_ptr = global_ex_ptr;
Copy link
Contributor

@piiswrong piiswrong Feb 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use std::rethrow_exception(std::move(global_ex_ptr))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rolling this back, Since the state of global_ex_ptr is not guaranteed to be nullptr after the move and it depends on the implementation. This probably explains why it started failing on windows, after the change. Please let me know if you have any concerns.

i->ex_ptr = threaded_opr->ex_ptr;
if (!global_ex_ptr) global_ex_ptr = i->ex_ptr;
}
bool debug_info = (engine_info_ && debug_wait_var_ == i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove const?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was not intentional. Added it back.

if (threaded_opr->ex_ptr) {
i->ex_ptr = threaded_opr->ex_ptr;
if (!global_ex_ptr) global_ex_ptr = i->ex_ptr;
}
Copy link
Contributor

@piiswrong piiswrong Feb 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suppose and operator has three outputs x, y and z and it raises an exception.
then x.asnumpy() would raise an error.
Then y.asnumpy() would raise the same error again.

and if I do z += 1 and it succeeds, z.asnumpy() would still raise the error

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

z += 1 won't execute since z already has an exception associated with it. z.asnumpy() will still raise the error.

@@ -175,6 +175,8 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::exception_ptr ex_ptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ex_ptr is a bad variable name

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i changed ex_ptr for var to var_exception, ex_ptr for opr to opr_exception and the global exception_ptr global_exception_ . I welcome if you have any other suggestions for naming them.

@@ -338,33 +346,46 @@ class ThreadedEngine : public Engine {
#endif
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnCompleteStatic, opr_block);
CallbackOnComplete on_start_callback = this->CreateCallback(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the point of creating a call back here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The callback is not strictly necessary here, since it is called only once, but I included it to keep the ExecuteOprBlock easier to read and separate out the OnStart logic.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is unnecessary overhead. Call OnStart directly if possible

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Called OnStart directly.

} else {
callback();
}
} catch (dmlc::Error& e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why only catch dmlc error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I catch only dmlc::Error is that The guards in the c_api API_BEGIN and API_END/API_END_HANDLE_ERROR only catch dmlc::Error currently and propagate to frontend.

Copy link
Member Author

@anirudh2290 anirudh2290 Feb 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have used dmlc::Error in the inner block and used std::exception to catch other stdlib exceptions thrown in the outer block. Currently, it will catch only dmlc::Error and for other exceptions(std::exception) the process will be terminated. I will open another PR to handle std::exception and change the c_api guards and frontend code.

@@ -182,7 +182,7 @@ class MXNET_API Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr, bool wait = false) = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Add \param for wait in doc.
  2. Put bool wait... in the next line to keep the coding style consistent with the existing context. Same all the following changes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed here and other places.

@@ -73,7 +73,7 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr, bool wait = false) override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put in the next line.

@@ -125,7 +125,7 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr, bool wait = false) override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Next line.

inputs = [x, y]
out = mx.symbol.ElementWiseSum(*inputs, name="esum")
out = mx.sym.dot(z, out)
out2 = mx.sym.random_normal(0, -1, x_shape, ctx=default_context())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random.normal as it's the preferred way now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

if (debug_info) {
LOG(INFO) << "Fin ExecuteOprFn ";
}
} catch(dmlc::Error &e) {
} catch (dmlc::Error& e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would this catch now? There is already a try block inside

Copy link
Member Author

@anirudh2290 anirudh2290 Feb 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@piiswrong my intention for the outer block was to catch all other exceptions which are not caught in the inner block. I should change dmlc::Error to std::exception to catch all standard exception. But you make a good point about propagating all the exceptions and not just dmlc::Error to the frontend. We can take one of the two approaches here: 1. catch dmlc::Error and terminate the process for everything else. 2. catch both dmlc::Error and standard exceptions and propagate to frontend. This will require code changes to the guards , c_api_error and potentially frontend code.

@@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) {
void* ptr;
#if _MSC_VER
ptr = _aligned_malloc(size, alignment_);
if (ptr == NULL) throw std::bad_alloc();
if (ptr == NULL) LOG(FATAL) << "Malloc failure";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Failed to allocate CPU memory

@@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) {
#endif // MXNET_USE_NCCL
cudaError_t e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading)
throw std::bad_alloc();
LOG(FATAL) << cudaGetLastError();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does cudaGetLastError return a string or an error code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I have changed it to call cudaGetErrorString

@@ -476,6 +507,9 @@ class ThreadedEngine : public Engine {
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
/*! \brief exception_ptr associated with the engine,
* which is used to throw exception in waitall */
std::exception_ptr global_ex_ptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ex_ptr is a bad name. class members should end with _.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed!

@piiswrong
Copy link
Contributor

suppose and operator has three outputs x, y and z and it raises an exception.
then x.asnumpy() would raise an error.
Then y.asnumpy() would raise the same error again.

I think an error should only be raised once. After it's raised, it should be cleared from all arrays that is pointing to that error.

This can be achieved by setting the object referenced by exception_ptr to an invalid value

@anirudh2290
Copy link
Member Author

@piiswrong Trying to understand your comment.

Lets say we have some code snippet like the below:

try:
       x, y, z = op()
       x.asnumpy()
except:
       handle_exc()
y = op2(y)
y.asnumpy()       

If we clear the exception_ptr corresponding to the var y when x.asnumpy() is executed, y may have some garbage value in it. op2 may end up executing fine, and after the last line y.asnumpy() we have no exception thrown. Shouldn't all the vars and ops which are in the chain following a failed op also fail ? Not doing this, will lead to the vars in the dependency chain following a failed op having non-deterministic garbage values depending on how the failed op and the following ops behave.

@KellenSunderland
Copy link
Contributor

KellenSunderland commented Feb 6, 2018

@anirudh2290 Good idea to use an example. From my perspective as a user I generally prefer that things fail as quickly as possible for me so that I don't have to track down root causes. Following this logic, I'd tend to agree with you, but would actually prefer it if this line threw an exception:

try:
       x, y, z = op()
       x.asnumpy()
except:
       handle_exc()
y = op2(y)  # Runtime exception
y.asnumpy()

I understand that this is a lazy operation, but wonder if it's still possible to do some failure validation here? If not the first blocking call (asnumpy) would also be a fairly intuitive place to throw.

@piiswrong
Copy link
Contributor

What I meant is, for example:

try:
       x, y, z = op()
       x.asnumpy()
except:
       handle_exc()
y.asnumpy()  # Fail

Currently y.asnumpy() will fail again with the same error as x.asnumpy().
But a single exception shouldn't be raised twice

@anirudh2290
Copy link
Member Author

try:
      x, y, z = op()
      x.asnumpy()
except:
      handle_exc()

'''
The below just pushes operation to the engine, no guarantee that op2 is executed (ExecuteOprBlock may not be called)
'''
y = op2(y)
y.asnumpy() # Guarantees that all the operations that are writing to y are executed

@KellenSunderland As you mentioned since it is a lazy operation, there is no guarantee that operation is executed, just that it is pushed to the engine. So, there is no guarantee that ExecuteOprBlock is called for the operator. On the other hand, it is guaranteed that all operations which write to a particular variable are executed when the blocking call on that variable is made. Therefore, I have rethrown exceptions in WaitForVar and WaitForAll. I understand that this may not be as intuitive to users as throwing on the y=op2(y) line itself, but I don't think it is possible to rethrow when it is pushed to engine.

@anirudh2290
Copy link
Member Author

try:
     x, y, z = op()
     x.asnumpy() #Throws exception, sets exception_ptr to nullptr
except:
     handle_exc()
y.asnumpy() #exception_ptr is nullptr, doesn't throw
y = op2(y)  
y.asnumpy() # y may have garbage values, op2 may execute just fine, exception_ptr still nullptr, doesn't throw ?

@piiswrong As depicted in the example above, if we decide to invalidate exception_ptr for y by setting it to nullptr when we WaitToRead x (I am unsure how we will do this), then we won't be propagating exceptions down the chain. Therefore, the last line here will execute just fine instead of throwing an exception, and user will end up with garbage values for y.

I understand your point that if an op has multiple write vars, and if we waited for one of the write vars and re-threw exception, we shouldn't throw it again for other vars. But, if we end up invalidating the exception_ptr, any continuing operators may or may not fail, and since the exception_ptr is invalidated we wouldn't be re-throwing the exception in any of the following WaitToReads.

@anirudh2290
Copy link
Member Author

After discussion with @piiswrong , we came to a conclusion that in an execution graph once an exception is thrown, the same exception should not be thrown again. For example:

x = None
y = None
try:
     x, y = op1() # Fails
     x.asnumpy() # Throw exception
except:
     handle_exc(x, y)
y.asnumpy() # Should execute fine
z = op2(x) # Should execute fine, expectation is that user modified the value of x from garbage, when handling exception

You can see that op1 throws an exception and it may end up writing garbage values to x and y. The line x.asnumpy() throws exception. Once this is done, user may handle the exception or keep the garbage values as it is. Any consequent usage of x or op2 should not throw the same exception, since user is not expecting and it is already handled.

One challenge during the implementation was that dereferencing exception_ptr in C++ will cause undefined behavior. So there is no way to modify state of the exception_object that exception_ptr points to, just by using exception_ptr itself. To workaround this limitation, we are holding the exception_ptr itself in a shared_ptr object.

We decided to remove global_exception thrown in WaitForAll, since it adds unnecessary complexity and is not really used much except during benchmarking.

@anirudh2290
Copy link
Member Author

@piiswrong: Do you have additional suggestions ?

@piiswrong piiswrong merged commit 7b24137 into apache:master Feb 13, 2018
larroy pushed a commit to larroy/mxnet that referenced this pull request Feb 21, 2018
* Add support for threaded engine

* Add support for threaded engine

* Remove on_start_callback for else

* Add support for global_ex_ptr

* Rethrow in waitall only once

* run tests for gpu

* Add comments for exception_ptr

* Fix lint

* Push exc_handling tests

* Add comments for OnStart

* Fixes for exc handling

* Catch std::exception for all other exceptions

* Rollback std::move use

* Fix style

* Fix onstart

* Fix debug_info

* Throw exception only once in an execution graph

* make test naming consistent

* Fix symbolic test

* Remove unused code
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
* Add support for threaded engine

* Add support for threaded engine

* Remove on_start_callback for else

* Add support for global_ex_ptr

* Rethrow in waitall only once

* run tests for gpu

* Add comments for exception_ptr

* Fix lint

* Push exc_handling tests

* Add comments for OnStart

* Fixes for exc handling

* Catch std::exception for all other exceptions

* Rollback std::move use

* Fix style

* Fix onstart

* Fix debug_info

* Throw exception only once in an execution graph

* make test naming consistent

* Fix symbolic test

* Remove unused code
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
* Add support for threaded engine

* Add support for threaded engine

* Remove on_start_callback for else

* Add support for global_ex_ptr

* Rethrow in waitall only once

* run tests for gpu

* Add comments for exception_ptr

* Fix lint

* Push exc_handling tests

* Add comments for OnStart

* Fixes for exc handling

* Catch std::exception for all other exceptions

* Rollback std::move use

* Fix style

* Fix onstart

* Fix debug_info

* Throw exception only once in an execution graph

* make test naming consistent

* Fix symbolic test

* Remove unused code
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
4 participants