Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[bug] fix higher grad log #15120

Merged
merged 7 commits into from
Jun 20, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 42 additions & 26 deletions src/operator/tensor/elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1090,68 +1090,84 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
unary_bwd<mshadow_op::log_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// For f(x) -> f = log
// ograds[0]: dL/dxgrad
// inputs[0]: dL/dy
// inputs[1]: x
// f(x) = y = log(x)
// f'(x) = 1/x
// f''(x) = -1 * (f'(x) * f'(x))
auto gx = nnvm::NodeEntry{n};
auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
{gx, gx}, nullptr, &n);
auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
{nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
auto dydx_mul_dldy = nnvm::NodeEntry{n}; // f'(x) * head_grads
auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
{n->inputs[1]}, nullptr, &n);
auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
{dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
{nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], gx}, nullptr, &n));
{ograds[0], nnvm::NodeEntry{dlogx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
{ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
{ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
return ret;
});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10,
unary_bwd<mshadow_op::log10_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// For f(x) -> f = log10
// ograds[0]: dL/dxgrad
// inputs[0]: dL/dy
// inputs[1]: x
// f(x) = y = log10(x)
// f'(x) = 1 / (log(10) * x)
// f''(x) = -1 * (f'(x) * 1/x)
auto gx = nnvm::NodeEntry{n, 0, 0};
auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
auto dydx_mul_dldy = nnvm::NodeEntry{n}; // f'(x) * head_grads
auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't elemwise_div require two inputs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it does. However I guess, this thing is skipped in the test computation graph and hence we don't see the error.

Will fix it. However we should somehow find a way to test for the same.

{n->inputs[0]}, nullptr, &n);
auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
{n->inputs[1]}, nullptr, &n);
auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
{gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
{nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
{dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
{nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], gx}, nullptr, &n));
{ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
{ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
{ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
return ret;
});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2,
unary_bwd<mshadow_op::log2_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// For f(x) -> f = log2
// ograds[0]: dL/dxgrad
// inputs[0]: dL/dy
// inputs[1]: x
// f(x) = y = log10(x)
kshitij12345 marked this conversation as resolved.
Show resolved Hide resolved
// f'(x) = 1 / (log(2) * x)
// f''(x) = -1 * (f'(x) * 1/x)
auto gx = nnvm::NodeEntry{n};
auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
auto dydx_mul_dldy = nnvm::NodeEntry{n}; // f'(x) * head_grads
auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
{n->inputs[0]}, nullptr, &n);
auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
{n->inputs[1]}, nullptr, &n);
auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
{gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
{nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
{dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
{nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], gx}, nullptr, &n));
{ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
{ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
{ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
return ret;
});

Expand Down
21 changes: 17 additions & 4 deletions tests/python/unittest/test_higher_order_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,26 @@ def grad_grad_op(x):

def check_second_order_unary(x, op, grad_grad_op):
x = nd.array(x)
expect_grad_grad = grad_grad_op(x)
grad_grad_x = grad_grad_op(x)
x.attach_grad()

# Manual head_grads.
y_grad = nd.random.normal(shape=x.shape)
head_grad_grads = nd.random.normal(shape=x.shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't understand what this variable is mathematically...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

head_grads is just the input node in the graph for x_grad.
head_grad_grads is just to check the validity of the chain rule/backprop.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification.


# Perform compute.
with autograd.record():
y = op(x)
y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
y_grad.backward()
assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
create_graph=True, retain_graph=True)[0]
x_grad.backward(head_grad_grads)

# Compute expected values.
expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
y_grad.asnumpy()

# Validate the gradients.
assert_almost_equal(expected_grad_grad, x.grad.asnumpy())


if __name__ == '__main__':
Expand Down