This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-978] Second order gradient support for some unary operators #14613

Merged: 34 commits merged on Jun 10, 2019
Changes from all commits
Commits (34)
45e1502
try to add support some ops
sxjscience Oct 14, 2018
904adb4
Merge branch 'higher_order_sample' of https://github.com/sxjscience/m…
apeforest Mar 7, 2019
d5dc994
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Mar 12, 2019
0e69075
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Mar 19, 2019
0c7cf98
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Apr 2, 2019
492e4cd
add unit test for second order grad
apeforest Apr 3, 2019
45b334e
implement grad for relu and add unit test
apeforest Apr 3, 2019
3bbfbac
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Apr 4, 2019
4dc0907
fix lint
apeforest Apr 5, 2019
c4034b2
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Apr 5, 2019
3fe54e6
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 16, 2019
76aa6ad
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 16, 2019
8458717
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 21, 2019
f66610b
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 23, 2019
30ff1e9
register FGradient attribute for backward relu
apeforest May 28, 2019
8ecffcc
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 28, 2019
d9ba3da
resolve conflict
apeforest May 28, 2019
1c93c7d
remove unused imports
apeforest May 28, 2019
de721bc
change gradient using set_attr
apeforest May 30, 2019
0ac0942
remove higher order grad test for negative(x)
apeforest May 30, 2019
f8e624e
fix lint
apeforest May 30, 2019
3315124
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 30, 2019
8538980
reverse indent
apeforest May 30, 2019
1ee38b5
remove unused backward operator
apeforest May 30, 2019
c18f317
refactor backward for sin(x) and cos(x)
apeforest May 30, 2019
689cfee
change value init to list init
apeforest May 30, 2019
d56e132
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 30, 2019
2207815
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 31, 2019
0b6c2ef
change to list initialization
apeforest May 31, 2019
31f671f
generate random shape in test
apeforest May 31, 2019
62fcca3
fix a bug in second order backward
apeforest Jun 3, 2019
a0a0e75
fix lint
apeforest Jun 3, 2019
451c4bd
fix lint
apeforest Jun 4, 2019
b9b0c93
address reviewer comment and renaming
apeforest Jun 5, 2019
22 changes: 20 additions & 2 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -85,8 +85,26 @@ The storage type of ``relu`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu,
unary_bwd<mshadow_op::relu_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu, unary_bwd<mshadow_op::relu_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> ret;
// ograds[0]: dL/dxgrad
// inputs[0]: dL/dy
// inputs[1]: y
// f(x) -> relu(x)
// f'(x) = 1 if x > 0 else 0
// f''(x) = 0
auto dydx = MakeNode("_greater", n->attrs.name + "_dydx",
{n->inputs[1], nnvm::NodeEntry{
MakeNode("zeros_like", n->attrs.name + "tmp", {n->inputs[1]}, nullptr, &n)
}}, nullptr, &n);
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], nnvm::NodeEntry(dydx)}, nullptr, &n));
ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_backward_grad_grad_in",
{n->inputs[1]}, nullptr, &n));
return ret;
});

// sigmoid
MXNET_OPERATOR_REGISTER_UNARY(sigmoid)
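An illustrative sketch, assuming the mx.autograd.grad API with create_graph=True, of how the second-order gradient registered above for _backward_relu behaves from the Python frontend; since the registration encodes f''(x) = 0, the second-order gradient should come back as all zeros:

from mxnet import nd, autograd

x = nd.array([-2.0, -0.5, 0.5, 2.0])
x.attach_grad()
with autograd.record():
    y = nd.relu(x)
    # keep the first-order gradient differentiable so it can be backpropagated again
    x_grad = autograd.grad(heads=y, variables=x, create_graph=True, retain_graph=True)[0]
x_grad.backward()  # second-order pass; the result accumulates into x.grad
print(x.grad)      # expected: all zeros, matching f''(x) = 0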
60 changes: 58 additions & 2 deletions src/operator/tensor/elemwise_unary_op_trig.cc
@@ -46,7 +46,33 @@ The storage type of ``sin`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_sin" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sin, unary_bwd<mshadow_op::sin_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sin, unary_bwd<mshadow_op::sin_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// ograds[0]: d^2L/dx^2
// inputs[0]: dL/dy
// inputs[1]: x (ElemwiseUseIn)
// f(x) = sin(x)
Member: same as above

apeforest (Contributor Author): This is a hidden operator, so users do not see it.

// f'(x) = cos(x)
// f''(x) = -sin(x)
auto dydx = MakeNode("cos", n->attrs.name + "_dydx",
{n->inputs[1]}, nullptr, &n);
auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
{nnvm::NodeEntry{
MakeNode("sin", n->attrs.name + "_grad_grad_mid", {n->inputs[1]}, nullptr, &n)
}}, nullptr, &n);

auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_mid",
{n->inputs[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
{ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
return ret;
});

// cos
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(cos, cpu, mshadow_op::cos)
@@ -63,7 +89,37 @@ The storage type of ``cos`` output is always dense
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_cos"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_cos, unary_bwd<mshadow_op::cos_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_cos, unary_bwd<mshadow_op::cos_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// ograds[0]: d^2L/dx^2
// inputs[0]: dL/dy
// inputs[1]: x (ElemwiseUseIn)
// f(x) = cos(x)
// f'(x) = -sin(x)
// f''(x) = -cos(x)
auto dydx = MakeNode("negative", n->attrs.name + "_dydx",
{nnvm::NodeEntry{
MakeNode("sin", n->attrs.name + "_grad_mid", {n->inputs[1]}, nullptr, &n)
}}, nullptr, &n);
auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
{nnvm::NodeEntry{
MakeNode("cos", n->attrs.name + "_grad_grad_mid", {n->inputs[1]}, nullptr, &n)
}}, nullptr, &n);

auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_mid",
{n->inputs[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;
// for the backward of the _backward_cos node
// first input is the ograd and second input is x (because ElemwiseUseIn)
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
{ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
return ret;
});


// tan
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(tan, cpu, mshadow_op::tan)
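Likewise, an illustrative sketch (same assumed autograd API) of the second derivative these sin/cos registrations enable, checked against the analytic value:

from mxnet import nd, autograd

x = nd.array([0.1, 0.7, 1.3])
x.attach_grad()
with autograd.record():
    y = nd.sin(x)
    # first-order gradient dy/dx = cos(x), kept in the graph
    dy_dx = autograd.grad(heads=y, variables=x, create_graph=True, retain_graph=True)[0]
dy_dx.backward()
print(x.grad)      # second derivative of sin(x), i.e. -sin(x)
print(-nd.sin(x))  # analytic value; cos(x) works analogously with -cos(x)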
64 changes: 53 additions & 11 deletions tests/python/unittest/test_higher_order_grad.py
@@ -15,13 +15,55 @@
# specific language governing permissions and limitations
# under the License.

import math

import math
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal, random_arrays
from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd
from common import with_seed


@with_seed()
def test_sin():
def sin(x):
return nd.sin(x)

def grad_grad_op(x):
return -nd.sin(x)

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, sin, grad_grad_op)
Member: I think the check_second_order_unary function should be moved to python/mxnet/test_utils.py.

apeforest (Contributor Author): It's only used in this test. If we add a different test file, then it makes sense to move it as you suggested.
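A minimal sketch of what a helper like check_second_order_unary could look like, shown here only for illustration and assuming the mx.autograd.grad API; the actual helper in this test file may differ (for example by also randomizing head gradients):

from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal

def check_second_order_unary(x, op, grad_grad_op):
    x = nd.array(x)
    expected = grad_grad_op(x).asnumpy()
    x.attach_grad()
    with autograd.record():
        y = op(x)
        # first-order gradient, kept differentiable for the second pass
        x_grad = autograd.grad(heads=y, variables=x, create_graph=True, retain_graph=True)[0]
    x_grad.backward()  # writes f''(x) elementwise into x.grad for elementwise unary ops
    assert_almost_equal(expected, x.grad.asnumpy())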



@with_seed()
def test_cos():
def cos(x):
return nd.cos(x)

def grad_grad_op(x):
return -nd.cos(x)

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, cos, grad_grad_op)

Reviewer: Can these check_second_order_unary checks be changed to Nth order?

apeforest (Contributor Author), Jun 5, 2019: This PR is only meant to verify second-order gradients. Can we add tests for Nth-order gradients in a separate PR?
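A hypothetical sketch of how such an Nth-order check could be written by differentiating repeatedly with create_graph=True; this only illustrates the idea deferred above and assumes the operators involved register gradients deep enough for the requested order:

from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal

def check_nth_order_unary(x, op, nth_grad_op, n):
    x = nd.array(x)
    expected = nth_grad_op(x).asnumpy()
    x.attach_grad()
    with autograd.record():
        y = op(x)
        for _ in range(n - 1):
            # keep each intermediate gradient in the graph so it can be differentiated again
            y = autograd.grad(heads=y, variables=x, create_graph=True, retain_graph=True)[0]
    y.backward()  # the n-th differentiation accumulates into x.grad
    assert_almost_equal(expected, x.grad.asnumpy())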



@with_seed()
def test_relu():
def relu(x):
return nd.relu(x)

def grad_grad_op(x):
return nd.zeros_like(x)

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, relu, grad_grad_op)


@with_seed()
def test_log():
def log(x):
@@ -30,9 +72,9 @@ def log(x):
def grad_grad_op(x):
return -1/(x**2)

arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

for array in arrays:
for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, log, grad_grad_op)


@@ -44,9 +86,9 @@ def log2(x):
def grad_grad_op(x):
return -1/((x**2) * math.log(2))

arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

for array in arrays:
for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, log2, grad_grad_op)


@@ -58,9 +100,9 @@ def log10(x):
def grad_grad_op(x):
return -1/((x**2) * math.log(10))

arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

for array in arrays:
for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, log10, grad_grad_op)

