fix gradient(nan) when two inputs are equal (#32448)
zhangting2020 authored Apr 25, 2021
1 parent 727b28d commit 1896c77
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions paddle/fluid/operators/dist_op.h
@@ -167,6 +167,7 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
   auto sign =
       (x_minux_y > static_cast<T>(0)).template cast<T>() * static_cast<T>(1.0) +
       (x_minux_y < static_cast<T>(0)).template cast<T>() * static_cast<T>(-1.0);
+  T epsilon = static_cast<T>(1.0e-10f);
 
   // 1: Lp-norm(z), z = x-y, compute dz
   if (p == 0) {
@@ -189,12 +190,14 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
     // dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
     if (platform::is_cpu_place(context.GetPlace())) {
       grad_t.device(place) =
-          (x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) *
+          (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
+              .pow(p - 1) *
           sign.eval() * out_grad_t.broadcast(out_bcast_dims);
     } else {
       grad_t.device(place) =
-          (x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign *
-          out_grad_t.broadcast(out_bcast_dims);
+          (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
+              .pow(p - 1) *
+          sign * out_grad_t.broadcast(out_bcast_dims);
     }
   }
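
Why the change is needed: when the two inputs are element-wise equal, both |x - y| and the Lp-norm output out are zero, so the backward formula evaluates pow(0 / 0, p - 1), which is NaN. Adding a tiny epsilon to the denominator keeps the ratio finite, and because sign(0) is already 0 in the code above, the resulting gradient becomes 0 instead of NaN. Below is a minimal standalone C++ sketch (not Paddle code; plain scalar floats stand in for the Eigen tensor expressions) that illustrates the effect:

// Scalar illustration of the NaN and of the epsilon fix from this commit.
#include <cmath>
#include <cstdio>

int main() {
  float x = 3.0f, y = 3.0f, p = 2.0f;
  float abs_diff = std::fabs(x - y);  // |x - y| == 0 when the inputs are equal
  float out = abs_diff;               // the p-norm of an all-zero difference is also 0

  // Backward term before the fix: pow(0 / 0, p - 1) -> NaN
  float without_eps = std::pow(abs_diff / out, p - 1);

  // Backward term after the fix: pow(0 / (0 + eps), p - 1) -> 0
  float epsilon = 1.0e-10f;           // same constant the patch introduces
  float with_eps = std::pow(abs_diff / (out + epsilon), p - 1);

  std::printf("without epsilon: %f\n", without_eps);  // prints nan
  std::printf("with epsilon:    %f\n", with_eps);     // prints 0.000000
  return 0;
}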

