Skip to content

Commit

Permalink
optimize composite OP silu_double_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Feb 27, 2024
1 parent 7a97c10 commit 3df912d
Showing 1 changed file with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -441,15 +441,13 @@ void silu_double_grad(const Tensor& x,
Tensor* grad_x,
Tensor* grad_out_grad) {
auto sigmoid = 1 / (1 + exp<T>(-x));
auto tmp1 = 1 - sigmoid;
auto tmp2 = 1 + tmp1 * x;
auto ddx_mul_tt = grad_x_grad * (x - out + 1);
if (grad_out_grad) {
auto ddout = grad_x_grad * sigmoid * tmp2;
set_output<T>(ddout, grad_out_grad);
set_output<T>(ddx_mul_tt * sigmoid, grad_out_grad);
}
if (grad_x) {
auto dx = sigmoid * grad_x_grad * out_grad * (1 + (tmp2 - out)) * tmp1;
set_output<T>(dx, grad_x);
auto sigmoid_g = sigmoid * (1 - sigmoid);
set_output<T>(ddx_mul_tt * sigmoid_g, grad_x);
}
}

Expand Down

0 comments on commit 3df912d

Please sign in to comment.