[MXNET-978] Second order gradient support for some unary operators #14613
@@ -85,8 +85,20 @@ The storage type of ``relu`` output depends upon the input storage type:

 )code" ADD_FILELINE)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"});

-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu,
-                                               unary_bwd<mshadow_op::relu_grad>);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu, unary_bwd<mshadow_op::relu_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      std::vector<nnvm::NodeEntry> ret;
+      // f(x) -> f = relu
+      // f'(x) = 1 if x > 0 else 0
+      // f''(x) = 0
+      auto gx = nnvm::NodeEntry{n};  // f'(x)
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
+                                {ograds[0], gx}, nullptr, &n));
+      ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_backward_grad_grad_in",
+                                {gx}, nullptr, &n));
+      return ret;
+    });

Reviewer: Similar to what you have done below for … Explanation: …
Author: Updated.
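To see what this registration enables end to end, here is a minimal sketch (not part of this diff) of driving a second backward pass over relu from Python; it assumes the mxnet.autograd.grad API with create_graph=True:

    from mxnet import nd, autograd

    x = nd.array([[-1.0, 2.0], [3.0, -4.0]])
    x.attach_grad()
    with autograd.record():
        y = nd.relu(x)
        # Keep the first-order gradient in the recorded graph so it can be
        # differentiated a second time.
        x_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    x_grad.backward()  # second pass, exercising the new FGradient of _backward_relu
    print(x.grad.asnumpy())  # f''(x) = 0 everywhere, so all zeros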
 // sigmoid
 MXNET_OPERATOR_REGISTER_UNARY(sigmoid)
@@ -46,7 +46,26 @@ The storage type of ``sin`` output depends upon the input storage type:

 )code" ADD_FILELINE)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_sin" });

-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sin, unary_bwd<mshadow_op::sin_grad>);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sin, unary_bwd<mshadow_op::sin_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // f(x) = sin(x)
+      // f'(x) = cos(x)
+      // f''(x) = -sin(x)
+      auto grad_x = nnvm::NodeEntry(n);
+      auto grad_grad_x_mid = MakeNode("sin", n->attrs.name + "_mid_grad_grad",
+                                      {n->inputs[1]}, nullptr, &n);
+      auto grad_grad_x = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
+                                  {nnvm::NodeEntry(grad_grad_x_mid)}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+      // For the backward of the _backward_sin node:
+      // the first input is the ograd and the second input is x (because of ElemwiseGradUseIn).
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
+                                {ograds[0], grad_x}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry(grad_grad_x)}, nullptr, &n));
+      return ret;
+    });

Reviewer: Same as above.
Author: This is a hidden operator, so users do not see it.

Reviewer: List initialisation. Otherwise LGTM.
Author: Updated.
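The identity f''(x) = -sin(x) that the comments rely on is easy to sanity-check numerically. A small sketch (not from the PR), using a central difference of f'(x) = cos(x):

    import numpy as np

    x = np.linspace(-3.0, 3.0, 7)
    eps = 1e-5
    # Central difference of f'(x) = cos(x) approximates f''(x).
    approx_f2 = (np.cos(x + eps) - np.cos(x - eps)) / (2 * eps)
    assert np.allclose(approx_f2, -np.sin(x), atol=1e-4)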
 // cos
 MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(cos, cpu, mshadow_op::cos)
@@ -63,7 +82,27 @@ The storage type of ``cos`` output is always dense

 )code" ADD_FILELINE)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_cos"});

-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_cos, unary_bwd<mshadow_op::cos_grad>);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_cos, unary_bwd<mshadow_op::cos_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // f(x) = cos(x)
+      // f'(x) = -sin(x)
+      // f''(x) = -cos(x)
+      auto grad_x = nnvm::NodeEntry(n);
+      auto grad_grad_x_mid = MakeNode("cos", n->attrs.name + "_mid_grad_grad",
+                                      {n->inputs[1]}, nullptr, &n);
+      auto grad_grad_x = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
+                                  {nnvm::NodeEntry(grad_grad_x_mid)}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+      // For the backward of the _backward_cos node:
+      // the first input is the ograd and the second input is x (because of ElemwiseGradUseIn).
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
+                                {ograds[0], grad_x}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry(grad_grad_x)}, nullptr, &n));
+      return ret;
+    });

Reviewer: Ditto.
Author: Updated.

Reviewer: You can move the shared_ptrs inside, as you are not using them any more. If CI passes it's OK; I can refactor it myself later. I wouldn't do another CI round for this.
Author: Isn't the NodeEntry constructor already doing the std::move of the shared_ptr? I thought declaring a separate variable in this scope increases readability.
Reviewer: The argument is taken by value, so it needs to be an rvalue if you want it moved both times; otherwise it is copied, then moved. So it is best to … Taking a shared_ptr by value and then moving it into the class seems to be the accepted idiom nowadays, since it allows both move and copy semantics, depending on whether the caller passes an rvalue or not.
 // tan
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(tan, cpu, mshadow_op::tan)
@@ -15,13 +15,52 @@
 # specific language governing permissions and limitations
 # under the License.

 import math
 from mxnet import nd, autograd
 from mxnet.test_utils import assert_almost_equal, random_arrays
 from common import with_seed
+@with_seed()
+def test_sin():
+    def sin(x):
+        return nd.sin(x)
+
+    def grad_grad_op(x):
+        return -nd.sin(x)
+
+    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
+    for array in arrays:
+        check_second_order_unary(array, sin, grad_grad_op)

Reviewer: I think …
Author: It's only used in this test. If we add a different test file, then it makes sense as you suggested.
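check_second_order_unary itself is not part of this hunk. A plausible sketch of what such a helper needs to do (an assumption, not the PR's actual implementation), using two autograd passes with create_graph=True:

    from mxnet import nd, autograd
    from mxnet.test_utils import assert_almost_equal

    def check_second_order_unary(array, op, grad_grad_op):
        x = nd.array(array)
        expected = grad_grad_op(x)  # analytical f''(x)
        x.attach_grad()
        with autograd.record():
            y = op(x)
            # First-order gradient, kept differentiable for the second pass.
            x_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
        x_grad.backward()  # second-order pass with head grads of ones
        assert_almost_equal(expected.asnumpy(), x.grad.asnumpy())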
+@with_seed()
+def test_cos():
+    def cos(x):
+        return nd.cos(x)
+
+    def grad_grad_op(x):
+        return -nd.cos(x)
+
+    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
+    for array in arrays:
+        check_second_order_unary(array, cos, grad_grad_op)

Reviewer: Can these …
Author: This PR is only to verify the second-order gradient. Can we add tests for Nth-order gradients in a separate PR?
+@with_seed()
+def test_relu():
+    def relu(x):
+        return nd.relu(x)
+
+    def grad_grad_op(x):
+        return nd.zeros_like(x)
+
+    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
+    for array in arrays:
+        check_second_order_unary(array, relu, grad_grad_op)

Reviewer: Shouldn't we test 1-d arrays? Not sure if it is needed here, but there is a helper to randomize the shape of an array: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/test_utils.py#L418
Author: Thanks for the suggestion. Updated.
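The randomized-shape variant the reviewer suggests would look roughly like this (a sketch; rand_shape_nd is assumed to be the helper behind the linked line, which is not shown here):

    from mxnet.test_utils import rand_shape_nd, random_arrays

    # One random shape per rank, from 1-d up to 4-d, each dimension capped at 5.
    shapes = [rand_shape_nd(ndim, dim=5) for ndim in range(1, 5)]
    arrays = random_arrays(*shapes)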
 @with_seed()
 def test_log():
     def log(x):
         …
Reviewer: Should we be having these comments? Couldn't they be included as part of the R"code( … )code" docstring?
Author: This is a hidden operator, so users do not see it.