From 62d9e51914427bc734da70790f9a0b41e1ae9c76 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 28 Feb 2024 04:56:23 +0000 Subject: [PATCH] correct computation equation --- .../composite_backward/composite_double_backward_api.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h index f2dfffceca320..4235425e75a8b 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -441,13 +441,15 @@ void silu_double_grad(const Tensor& x, Tensor* grad_x, Tensor* grad_out_grad) { auto sigmoid = 1 / (1 + exp(-x)); - auto ddx_mul_tt = grad_x_grad * (x - out + 1); + auto tmp = 1 + x - out; + auto ddx_mul_sigmoid = grad_x_grad * sigmoid; if (grad_out_grad) { - set_output(ddx_mul_tt * sigmoid, grad_out_grad); + set_output(ddx_mul_sigmoid * tmp, grad_out_grad); } if (grad_x) { auto sigmoid_g = sigmoid * (1 - sigmoid); - set_output(ddx_mul_tt * sigmoid_g, grad_x); + set_output(ddx_mul_sigmoid * out_grad * (1 - sigmoid) * (tmp - out + 1), + grad_x); } }