-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
cdb6fc9
to
4415130
Compare
thanks! I can confirm that this fixes the swapaxes issue observed in Sockeye PR awslabs/sockeye#272 |
We need to test the case “grad_req=addto”. There seems to be many untested OPs. Here is an example of testing the correctness of AddTo: https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L777 |
const std::vector<TShape> &in_shape) const override { | ||
return {ResourceRequest::kTempSpace}; | ||
} | ||
|
||
Operator* CreateOperator(Context ctx) const override { | ||
LOG(FATAL) << "Not Implemented"; | ||
return NULL; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also revise the test (https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L361-L377).
Temp space is not necessay. Change line https://github.com/apache/incubator-mxnet/pull/9495/files#diff-f299af08ae07ceaecbd468970ef7ee9cR128 to += when AddTo is set |
@@ -135,7 +135,7 @@ class SwapAxisOp : public Operator { | |||
const std::vector<TBlob> &aux_args) { | |||
using namespace mshadow; | |||
Stream<xpu> *s = ctx.get_stream<xpu>(); | |||
|
|||
CHECK_NE(req[swapaxisenum::kOut], kAddTo); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would probably make sense to add a message here to not get very cryptic error messages. Would it not be possible to apply what has been done for the backward pass?
here's the version of the fix suggested by @piiswrong: #9541 |
Will close this PR as #9541 is the suitable fix. |
Description
the operator switch_axis was ignoring the "req" argument. That leads to wrong gradients being computed if an operator fans out to two or more other ones where one of them is swap_axis.
This error was detected on something that should become a production model.
Checklist
Essentials
make lint
)